cui-llama.rn 1.4.2 → 1.4.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/README.md +93 -114
  2. package/android/src/main/CMakeLists.txt +5 -0
  3. package/android/src/main/build-arm64/CMakeCache.txt +429 -0
  4. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +81 -0
  5. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +101 -0
  6. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  7. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  8. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeSystem.cmake +15 -0
  9. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.c +904 -0
  10. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdC/CMakeCCompilerId.o +0 -0
  11. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.cpp +919 -0
  12. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CompilerIdCXX/CMakeCXXCompilerId.o +0 -0
  13. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +431 -0
  14. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +16 -0
  15. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +165 -0
  16. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +297 -0
  17. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +1 -0
  18. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +1 -0
  19. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +1 -0
  20. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +1 -0
  21. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +1 -0
  22. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +1 -0
  23. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +1 -0
  24. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +8 -0
  25. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +1 -0
  26. package/android/src/main/build-arm64/CMakeFiles/progress.marks +1 -0
  27. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  28. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +58 -0
  29. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  30. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +756 -0
  31. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  32. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +709 -0
  33. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  34. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +714 -0
  35. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  36. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +62 -0
  37. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  38. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +708 -0
  39. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  40. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +113 -0
  41. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  42. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +713 -0
  43. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  44. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +763 -0
  45. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  46. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +61 -0
  47. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  48. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +707 -0
  49. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  50. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +104 -0
  51. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  52. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +714 -0
  53. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  54. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +723 -0
  55. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +62 -0
  56. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +722 -0
  57. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +89 -0
  58. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +2 -0
  59. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +2 -0
  60. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +2 -0
  61. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +17 -0
  62. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +41 -0
  63. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +62 -0
  64. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +722 -0
  65. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +89 -0
  66. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +2 -0
  67. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +2 -0
  68. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +2 -0
  69. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +17 -0
  70. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +41 -0
  71. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +62 -0
  72. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +722 -0
  73. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +89 -0
  74. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +2 -0
  75. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +2 -0
  76. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +2 -0
  77. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +17 -0
  78. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +41 -0
  79. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +62 -0
  80. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +722 -0
  81. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +89 -0
  82. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +2 -0
  83. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +2 -0
  84. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +2 -0
  85. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +17 -0
  86. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +41 -0
  87. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +62 -0
  88. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +722 -0
  89. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +89 -0
  90. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +2 -0
  91. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +2 -0
  92. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +2 -0
  93. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +17 -0
  94. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +41 -0
  95. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +62 -0
  96. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +722 -0
  97. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +89 -0
  98. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +2 -0
  99. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +2 -0
  100. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +2 -0
  101. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +17 -0
  102. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +41 -0
  103. package/android/src/main/build-arm64/Makefile +1862 -0
  104. package/android/src/main/build-arm64/cmake_install.cmake +66 -0
  105. package/android/src/main/java/com/rnllama/LlamaContext.java +92 -18
  106. package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
  107. package/android/src/main/jni-utils.h +6 -0
  108. package/android/src/main/jni.cpp +287 -31
  109. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  110. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  111. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  112. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  113. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  114. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  115. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  116. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  117. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
  118. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
  119. package/cpp/chat-template.hpp +529 -0
  120. package/cpp/chat.cpp +1085 -0
  121. package/cpp/chat.hpp +55 -0
  122. package/cpp/common.cpp +159 -36
  123. package/cpp/common.h +64 -19
  124. package/cpp/ggml-alloc.c +1 -13
  125. package/cpp/ggml-common.h +0 -2
  126. package/cpp/ggml-cpu-impl.h +6 -12
  127. package/cpp/ggml-cpu-quants.c +937 -340
  128. package/cpp/ggml-cpu.c +207 -113
  129. package/cpp/ggml-cpu.cpp +4 -6
  130. package/cpp/ggml-cpu.h +1 -1
  131. package/cpp/ggml-metal.h +66 -66
  132. package/cpp/ggml-metal.m +141 -23
  133. package/cpp/ggml.c +24 -14
  134. package/cpp/ggml.h +2 -2
  135. package/cpp/json-schema-to-grammar.cpp +46 -66
  136. package/cpp/json-schema-to-grammar.h +15 -1
  137. package/cpp/llama-arch.cpp +7 -2
  138. package/cpp/llama-arch.h +3 -1
  139. package/cpp/llama-chat.cpp +10 -1
  140. package/cpp/llama-chat.h +1 -0
  141. package/cpp/llama-grammar.cpp +86 -6
  142. package/cpp/llama-grammar.h +22 -1
  143. package/cpp/llama-impl.h +6 -6
  144. package/cpp/llama-kv-cache.h +1 -1
  145. package/cpp/llama-mmap.h +1 -0
  146. package/cpp/llama-model-loader.cpp +1 -1
  147. package/cpp/llama-model.cpp +32 -6
  148. package/cpp/llama-sampling.cpp +178 -61
  149. package/cpp/llama-vocab.cpp +8 -3
  150. package/cpp/llama.cpp +188 -128
  151. package/cpp/llama.h +27 -10
  152. package/cpp/log.cpp +32 -10
  153. package/cpp/log.h +12 -1
  154. package/cpp/minja.hpp +2883 -0
  155. package/cpp/rn-llama.cpp +82 -5
  156. package/cpp/rn-llama.h +16 -1
  157. package/cpp/sampling.cpp +68 -41
  158. package/cpp/sampling.h +3 -0
  159. package/cpp/sgemm.cpp +9 -8
  160. package/cpp/unicode.cpp +9 -2
  161. package/ios/CMakeLists.txt +6 -0
  162. package/ios/RNLlama.h +0 -8
  163. package/ios/RNLlama.mm +27 -3
  164. package/ios/RNLlamaContext.h +10 -1
  165. package/ios/RNLlamaContext.mm +269 -57
  166. package/jest/mock.js +21 -2
  167. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  168. package/lib/commonjs/grammar.js +3 -0
  169. package/lib/commonjs/grammar.js.map +1 -1
  170. package/lib/commonjs/index.js +87 -13
  171. package/lib/commonjs/index.js.map +1 -1
  172. package/lib/module/NativeRNLlama.js.map +1 -1
  173. package/lib/module/grammar.js +3 -0
  174. package/lib/module/grammar.js.map +1 -1
  175. package/lib/module/index.js +86 -13
  176. package/lib/module/index.js.map +1 -1
  177. package/lib/typescript/NativeRNLlama.d.ts +107 -2
  178. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  179. package/lib/typescript/grammar.d.ts.map +1 -1
  180. package/lib/typescript/index.d.ts +32 -7
  181. package/lib/typescript/index.d.ts.map +1 -1
  182. package/llama-rn.podspec +1 -1
  183. package/package.json +2 -2
  184. package/src/NativeRNLlama.ts +115 -3
  185. package/src/grammar.ts +3 -0
  186. package/src/index.ts +138 -21
package/cpp/ggml-cpu.c CHANGED
@@ -7,7 +7,6 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
-#include "ggml-quants.h"
 #include "ggml-cpu-quants.h"
 #include "ggml-threading.h"
 #include "ggml.h"
@@ -1077,29 +1076,23 @@ do { \
 #define LM_GGML_F16_STEP 32
 #define LM_GGML_F16_EPR 8
 
-// F16 arithmetic is not supported by AVX, so we use F32 instead
+// F16 arithmetic is not supported by LASX, so we use F32 instead
 
 #define LM_GGML_F32Cx8 __m256
 #define LM_GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
 #define LM_GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
 
 static inline __m256 __lasx_f32cx8_load(const lm_ggml_fp16_t * x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = LM_GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return (__m256)__lasx_xvld(tmp, 0);
+    __m256i a;
+    memcpy(&a, x, sizeof(lm_ggml_fp16_t) * 8);
+    a = __lasx_xvpermi_d(a, 0 | (1 << 4));
+    return __lasx_xvfcvtl_s_h(a);
 }
-static inline void __lasx_f32cx8_store(lm_ggml_fp16_t * x, __m256 y) {
-    float arr[8];
 
-    __lasx_xvst(y, arr, 0);
-
-    for (int i = 0; i < 8; i++) {
-        x[i] = LM_GGML_FP32_TO_FP16(arr[i]);
-    }
+static inline void __lasx_f32cx8_store(lm_ggml_fp16_t * x, __m256 y) {
+    __m256i a = __lasx_xvfcvt_h_s(y, y);
+    a = __lasx_xvpermi_d(a, 0 | (2 << 2));
+    memcpy(x, &a, sizeof(lm_ggml_fp16_t) * 8);
 }
 #define LM_GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
 #define LM_GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@@ -1296,12 +1289,12 @@ struct lm_ggml_threadpool {
     atomic_int n_graph; // incremented when there is work to be done (i.e each graph)
     atomic_int LM_GGML_CACHE_ALIGN n_barrier;
     atomic_int LM_GGML_CACHE_ALIGN n_barrier_passed;
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+    atomic_int LM_GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
 
     // these are atomic as an annotation for thread-sanitizer
     atomic_bool stop; // Used for stopping the threadpool altogether
     atomic_bool pause; // Used for pausing the threadpool or individual threads
-    atomic_bool abort; // Used for aborting processing of a graph
+    atomic_int abort; // Used for aborting processing of a graph
 
     struct lm_ggml_compute_state * workers; // per thread state
     int n_threads_max; // number of threads in the pool
@@ -1823,7 +1816,7 @@ inline static float lm_ggml_silu_f32(float x) {
 
 #if __FINITE_MATH_ONLY__
 #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
-#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
+#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
 #endif
 
 #if defined(__ARM_NEON) && defined(__aarch64__)
@@ -7495,6 +7488,7 @@ UseGgmlGemm1:;
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = lm_ggml_type_size(vec_dot_type);
         const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
@@ -7502,6 +7496,7 @@ UseGgmlGemm1:;
         assert(params->wsize >= ne13*nbw3);
         LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@@ -7511,6 +7506,20 @@ UseGgmlGemm1:;
                 }
             }
         }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = lm_ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
     if (ith == 0) {
@@ -7565,7 +7574,7 @@ UseGgmlGemm2:;
     int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
     // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
-    // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
+    // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggml-org/llama.cpp/pull/6915
     // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
     if (nchunk0 * nchunk1 < nth * 4 || lm_ggml_is_numa()) {
         // distribute the thread work across the inner or outer loop based on which one is larger
@@ -7598,7 +7607,6 @@ UseGgmlGemm2:;
         if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
             num_rows_per_vec_dot = 1;
         }
-
         lm_ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
@@ -7611,6 +7619,84 @@ UseGgmlGemm2:;
 
 // lm_ggml_compute_forward_mul_mat_id
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
+
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static void lm_ggml_compute_forward_mul_mat_id_one_chunk(
+    struct lm_ggml_tensor * dst,
+    const struct lm_ggml_tensor * src0,
+    const struct lm_ggml_tensor * src1,
+    const struct lm_ggml_tensor * ids,
+    const int64_t cur_a,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end,
+    const char * src0_cur,
+    const struct mmid_row_mapping * matrix_rows,
+    const size_t row_size,
+    const bool src1_cont,
+    const void * wdata) {
+
+    LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum lm_ggml_type type = src0->type;
+
+    lm_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
+    enum lm_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
+
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    float tmp[16];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
+                const int64_t _i12 = ir1; // logical row index for this expert
+
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                const int id = row_mapping.i1; // selected expert index
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
+
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                // the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                     ? (i11 + i12*ne11)*row_size
+                     : (i11*nb11 + i12*nb12));
+
+                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+                }
+
+                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
+            }
+        }
+    }
+}
+
+static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
+
+    void * ptr = *p;
+    ptr = (void *) LM_GGML_PAD((uintptr_t) ptr, align);
+    *p = (void *) ((char *) ptr + size);
+    return ptr;
+}
+
 static void lm_ggml_compute_forward_mul_mat_id(
     const struct lm_ggml_compute_params * params,
     struct lm_ggml_tensor * dst) {
@@ -7628,7 +7714,6 @@ static void lm_ggml_compute_forward_mul_mat_id(
 
     const bool src1_cont = lm_ggml_is_contiguous(src1);
 
-    lm_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
     enum lm_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
     lm_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
 
@@ -7646,21 +7731,27 @@ static void lm_ggml_compute_forward_mul_mat_id(
     const int n_ids = ids->ne[0]; // n_expert_used
     const int n_as = ne02; // n_expert
 
-    char * wdata_src1_end = (src1->type == vec_dot_type) ?
-            (char *) params->wdata :
-            (char *) params->wdata + LM_GGML_PAD(lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)), sizeof(int64_t));
+    void * wdata_cur = params->wdata;
 
-    struct mmid_row_mapping {
-        int32_t i1;
-        int32_t i2;
-    };
+    if (src1->type != vec_dot_type) {
+        incr_ptr_aligned(&wdata_cur, lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)), sizeof(int64_t));
+    }
+
+    int64_t * matrix_row_counts = // [n_as]
+        incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
 
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+    struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
+        incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
+
+    char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
+        incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
+
+    LM_GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
 
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = lm_ggml_type_size(vec_dot_type);
         const size_t nbw1 = lm_ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
@@ -7668,19 +7759,32 @@ static void lm_ggml_compute_forward_mul_mat_id(
         assert(params->wsize >= ne13*nbw3);
         LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+            for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
                     from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                                ne10);
                 }
            }
        }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = lm_ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
     if (ith == 0) {
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -7698,9 +7802,14 @@ static void lm_ggml_compute_forward_mul_mat_id(
         }
     }
 
+    // reset current_chunk
+    for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
+        *current_chunk_ctr = nth;
+    }
+
     lm_ggml_barrier(params->threadpool);
 
-    // compute each matrix multiplication in sequence
     for (int cur_a = 0; cur_a < n_as; ++cur_a) {
         const int64_t cne1 = matrix_row_counts[cur_a];
 
@@ -7708,84 +7817,64 @@ static void lm_ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
-
-        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const char * src0_cur = (const char *) src0->data + cur_a * nb02;
+        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = lm_ggml_row_size(vec_dot_type, ne10);
 
-        const int64_t nr0 = ne01; // src0 rows
-        const int64_t nr1 = cne1; // src1 rows
-
-        // distribute the thread work across the inner or outer loop based on which one is larger
-
-        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
-        const int64_t ith0 = ith % nth0;
-        const int64_t ith1 = ith / nth0;
-
-        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
-        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+        const int64_t nr0 = ne01;
+        const int64_t nr1 = cne1;
 
-        const int64_t ir010 = dr0*ith0;
-        const int64_t ir011 = MIN(ir010 + dr0, nr0);
-
-        const int64_t ir110 = dr1*ith1;
-        const int64_t ir111 = MIN(ir110 + dr1, nr1);
-
-        // threads with no work simply yield (not sure if it helps)
-        //if (ir010 >= ir011 || ir110 >= ir111) {
-        //    sched_yield();
-        //    continue;
-        //}
+        int chunk_size = 16;
+        if (nr0 == 1 || nr1 == 1) {
+            chunk_size = 64;
+        }
 
-        // block-tiling attempt
-        const int64_t blck_0 = 16;
-        const int64_t blck_1 = 16;
+#if defined(__aarch64__)
+        // disable for ARM
+        const bool disable_chunking = true;
+#else
+        // disable for NUMA
+        const bool disable_chunking = lm_ggml_is_numa();
+#endif // defined(__aarch64__)
 
-        // attempt to reduce false-sharing (does not seem to make a difference)
-        float tmp[16];
+        int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+        int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
-        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
-            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t _i12 = ir1; // logical row index for this expert
+        if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
+            nchunk0 = nr0 > nr1 ? nth : 1;
+            nchunk1 = nr0 > nr1 ? 1 : nth;
+        }
 
-                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
-                    const int id = row_mapping.i1; // selected expert index
+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-                    const int64_t i11 = id % ne11;
-                    const int64_t i12 = row_mapping.i2; // row index in src1
+        int current_chunk = ith;
 
-                    const int64_t i1 = id; // selected expert index
-                    const int64_t i2 = i12; // row
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
 
-                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                    // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                    // the original src1 data pointer, so we should index using the indices directly
-                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                    const char * src1_col = (const char *) wdata +
-                        (src1_cont || src1->type != vec_dot_type
-                         ? (i11 + i12*ne11)*row_size
-                         : (i11*nb11 + i12*nb12));
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
 
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+            const int64_t ir0_start = dr0 * ith0;
+            const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
 
-                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                    //}
+            const int64_t ir1_start = dr1 * ith1;
+            const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
-                    }
+            lm_ggml_compute_forward_mul_mat_id_one_chunk(
+                dst, src0, src1, ids, cur_a,
+                ir0_start, ir0_end, ir1_start, ir1_end,
+                src0_cur, matrix_rows, row_size, src1_cont, wdata
+            );
 
-                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
-                }
+            if (nth >= nchunk0 * nchunk1) {
+                break;
             }
+
+            current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
         }
     }
-
-#undef MMID_MATRIX_ROW
 }
 
 // lm_ggml_compute_forward_out_prod
@@ -7882,7 +7971,7 @@ static void lm_ggml_compute_forward_out_prod_f32(
 
                 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
                 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
 
                 lm_ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
             }
@@ -7891,7 +7980,7 @@ static void lm_ggml_compute_forward_out_prod_f32(
 
                 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
                 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
 
                 lm_ggml_vec_mad_f32(ne0, d, s0, *s1);
             }
@@ -9079,10 +9168,6 @@ static void lm_ggml_compute_forward_clamp_f32(
 
     const struct lm_ggml_tensor * src0 = dst->src[0];
 
-    if (params->ith != 0) {
-        return;
-    }
-
     float min;
     float max;
     memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
@@ -13722,14 +13807,19 @@ struct lm_ggml_cplan lm_ggml_graph_plan(
                     cur = 0;
                     const struct lm_ggml_tensor * src0 = node->src[0];
                     const struct lm_ggml_tensor * src1 = node->src[1];
+                    const struct lm_ggml_tensor * ids = node->src[2];
                     const enum lm_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
+                    const int n_as = src0->ne[2];
+                    // src1
                     if (src1->type != vec_dot_type) {
-                        cur += lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1));
+                        cur += lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)) + sizeof(int64_t);
                     }
-                    const int n_as = src0->ne[2];
-                    cur += LM_GGML_PAD(cur, sizeof(int64_t)); // align
-                    cur += n_as * sizeof(int64_t); // matrix_row_counts
-                    cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                    // matrix_row_counts
+                    cur += n_as * sizeof(int64_t) + sizeof(int64_t);
+                    // matrix_rows
+                    cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
+                    // atomic_current_chunk
+                    cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
                 } break;
             case LM_GGML_OP_OUT_PROD:
                 {
@@ -13850,20 +13940,24 @@ static thread_ret_t lm_ggml_graph_compute_thread(void * data) {
         /*.threadpool=*/ tp,
     };
 
-    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
+    for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
         struct lm_ggml_tensor * node = cgraph->nodes[node_n];
 
         lm_ggml_compute_forward(&params, node);
 
         if (state->ith == 0 && cplan->abort_callback &&
                 cplan->abort_callback(cplan->abort_callback_data)) {
-            tp->abort = true;
+            atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
             tp->ec = LM_GGML_STATUS_ABORTED;
         }
 
-        lm_ggml_barrier(state->threadpool);
+        if (node_n + 1 < cgraph->n_nodes) {
+            lm_ggml_barrier(state->threadpool);
+        }
     }
 
+    lm_ggml_barrier(state->threadpool);
+
     return 0;
 }
 
@@ -14030,7 +14124,7 @@ static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl(
     threadpool->current_chunk = 0;
     threadpool->stop = false;
     threadpool->pause = tpp->paused;
-    threadpool->abort = false;
+    threadpool->abort = -1;
     threadpool->workers = NULL;
     threadpool->n_threads_max = tpp->n_threads;
     threadpool->n_threads_cur = tpp->n_threads;
@@ -14109,7 +14203,7 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct
         threadpool->cgraph = cgraph;
         threadpool->cplan = cplan;
         threadpool->current_chunk = 0;
-        threadpool->abort = false;
+        threadpool->abort = -1;
         threadpool->ec = LM_GGML_STATUS_SUCCESS;
     }
 
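In the rewritten mul_mat_id path above, the fixed per-thread split of rows is replaced by chunked work claiming: every thread starts on the chunk equal to its own index, and once the first nth chunks are taken statically, further chunks are claimed through a relaxed atomic fetch-add on a per-expert counter that was reset to nth. A minimal standalone sketch of that claiming scheme, with hypothetical names and chunk count (C11 atomics plus pthreads, independent of the ggml sources):

    #include <pthread.h>
    #include <stdatomic.h>
    #include <stdio.h>

    #define N_THREADS 4
    #define N_CHUNKS  10 /* stands in for nchunk0 * nchunk1 (hypothetical value) */

    static atomic_int current_chunk_ctr; /* shared chunk counter, as in the hunks above */

    typedef struct { int ith; } worker_arg;

    /* Each worker processes chunk == its own index first, then keeps claiming
     * new chunks with a relaxed fetch-add until no chunks are left. */
    static void * worker(void * p) {
        const int ith = ((worker_arg *) p)->ith;

        int chunk = ith;
        while (chunk < N_CHUNKS) {
            printf("thread %d processes chunk %d\n", ith, chunk); /* stand-in for real work */
            chunk = atomic_fetch_add_explicit(&current_chunk_ctr, 1, memory_order_relaxed);
        }
        return NULL;
    }

    int main(void) {
        pthread_t  tid[N_THREADS];
        worker_arg arg[N_THREADS];

        /* chunks 0 .. N_THREADS-1 are taken statically, so the counter starts at
         * N_THREADS, mirroring the "*current_chunk_ctr = nth" reset in the diff */
        atomic_store(&current_chunk_ctr, N_THREADS);

        for (int i = 0; i < N_THREADS; ++i) {
            arg[i].ith = i;
            pthread_create(&tid[i], NULL, worker, &arg[i]);
        }
        for (int i = 0; i < N_THREADS; ++i) {
            pthread_join(tid[i], NULL);
        }
        return 0;
    }

Relaxed ordering is sufficient for the counter because it only hands out chunk indices; visibility of the produced data is handled separately (in the diff, by the existing lm_ggml_barrier calls).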
package/cpp/ggml-cpu.cpp CHANGED
@@ -283,14 +283,14 @@ struct lm_ggml_backend_cpu_device_context {
                          &hKey) == ERROR_SUCCESS) {
             DWORD cpu_brand_size = 0;
             if (RegQueryValueExA(hKey,
-                                TEXT("ProcessorNameString"),
+                                "ProcessorNameString",
                                 NULL,
                                 NULL,
                                 NULL,
                                 &cpu_brand_size) == ERROR_SUCCESS) {
                 description.resize(cpu_brand_size);
                 if (RegQueryValueExA(hKey,
-                                    TEXT("ProcessorNameString"),
+                                    "ProcessorNameString",
                                     NULL,
                                     NULL,
                                     (LPBYTE)&description[0], // NOLINT
@@ -415,7 +415,8 @@ static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, co
         case LM_GGML_OP_IM2COL_BACK:
             return src0->type == LM_GGML_TYPE_F32 && src1->type == LM_GGML_TYPE_F32;
         case LM_GGML_OP_OUT_PROD:
-            return (src0->type == LM_GGML_TYPE_F32 || lm_ggml_is_quantized(src0->type)) && src1->type == LM_GGML_TYPE_F32;
+            return (src0->type == LM_GGML_TYPE_F32 || (lm_ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+                src1->type == LM_GGML_TYPE_F32 && op->type == LM_GGML_TYPE_F32;
         default:
             return true;
     }
@@ -532,9 +533,6 @@ static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backen
         if (lm_ggml_cpu_has_dotprod()) {
             features.push_back({ "DOTPROD", "1" });
         }
-        if (lm_ggml_cpu_has_matmul_int8()) {
-            features.push_back({ "MATMUL_INT8", "1" });
-        }
         if (lm_ggml_cpu_get_sve_cnt() > 0) {
             static std::string sve_cnt = std::to_string(lm_ggml_cpu_get_sve_cnt());
             features.push_back({ "SVE_CNT", sve_cnt.c_str() });
package/cpp/ggml-cpu.h CHANGED
@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
 // the compute plan that needs to be prepared for lm_ggml_graph_compute()
-// since https://github.com/ggerganov/ggml/issues/287
+// since https://github.com/ggml-org/ggml/issues/287
 struct lm_ggml_cplan {
     size_t work_size; // size of work buffer, calculated by `lm_ggml_graph_plan()`
     uint8_t * work_data; // work buffer, to be allocated by caller before calling to `lm_ggml_graph_compute()`