cui-llama.rn 1.4.3 → 1.4.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. package/README.md +93 -114
  2. package/android/src/main/CMakeLists.txt +5 -0
  3. package/android/src/main/build-arm64/CMakeCache.txt +429 -0
  4. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCCompiler.cmake +21 -21
  5. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeCXXCompiler.cmake +101 -0
  6. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_C.bin +0 -0
  7. package/android/src/main/build-arm64/CMakeFiles/3.31.4/CMakeDetermineCompilerABI_CXX.bin +0 -0
  8. package/android/src/main/build-arm64/CMakeFiles/CMakeConfigureLog.yaml +376 -0
  9. package/android/src/main/build-arm64/CMakeFiles/CMakeDirectoryInformation.cmake +16 -0
  10. package/android/src/main/build-arm64/CMakeFiles/Makefile.cmake +165 -0
  11. package/android/src/main/build-arm64/CMakeFiles/Makefile2 +297 -0
  12. package/android/src/main/build-arm64/CMakeFiles/Progress/1 +1 -0
  13. package/android/src/main/build-arm64/CMakeFiles/Progress/2 +1 -0
  14. package/android/src/main/build-arm64/CMakeFiles/Progress/3 +1 -0
  15. package/android/src/main/build-arm64/CMakeFiles/Progress/4 +1 -0
  16. package/android/src/main/build-arm64/CMakeFiles/Progress/5 +1 -0
  17. package/android/src/main/build-arm64/CMakeFiles/Progress/6 +1 -0
  18. package/android/src/main/build-arm64/CMakeFiles/Progress/count.txt +1 -0
  19. package/android/src/main/build-arm64/CMakeFiles/TargetDirectories.txt +8 -0
  20. package/android/src/main/build-arm64/CMakeFiles/cmake.check_cache +1 -0
  21. package/android/src/main/build-arm64/CMakeFiles/progress.marks +1 -0
  22. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o +0 -0
  23. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-alloc.c.o.d +58 -0
  24. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o +0 -0
  25. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend-reg.cpp.o.d +756 -0
  26. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o +0 -0
  27. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-backend.cpp.o.d +709 -0
  28. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o +0 -0
  29. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-aarch64.cpp.o.d +714 -0
  30. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o +0 -0
  31. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-quants.c.o.d +62 -0
  32. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o +0 -0
  33. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu-traits.cpp.o.d +708 -0
  34. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o +0 -0
  35. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.c.o.d +113 -0
  36. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o +0 -0
  37. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-cpu.cpp.o.d +713 -0
  38. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o +0 -0
  39. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-opt.cpp.o.d +763 -0
  40. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o +0 -0
  41. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-quants.c.o.d +61 -0
  42. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o +0 -0
  43. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml-threading.cpp.o.d +707 -0
  44. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o +0 -0
  45. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/ggml.c.o.d +104 -0
  46. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o +0 -0
  47. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/gguf.cpp.o.d +714 -0
  48. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o +0 -0
  49. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/D_/dev/react-native/cui-llama.rn/cpp/log.cpp.o.d +723 -0
  50. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/DependInfo.cmake +62 -0
  51. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/build.make +722 -0
  52. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/cmake_clean.cmake +89 -0
  53. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.make +2 -0
  54. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/compiler_depend.ts +2 -0
  55. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/depend.make +2 -0
  56. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/flags.make +17 -0
  57. package/android/src/main/build-arm64/CMakeFiles/rnllama.dir/progress.make +41 -0
  58. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/DependInfo.cmake +62 -0
  59. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/build.make +722 -0
  60. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/cmake_clean.cmake +89 -0
  61. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.make +2 -0
  62. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/compiler_depend.ts +2 -0
  63. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/depend.make +2 -0
  64. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/flags.make +17 -0
  65. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8.dir/progress.make +41 -0
  66. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/DependInfo.cmake +62 -0
  67. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/build.make +722 -0
  68. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/cmake_clean.cmake +89 -0
  69. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.make +2 -0
  70. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/compiler_depend.ts +2 -0
  71. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/depend.make +2 -0
  72. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/flags.make +17 -0
  73. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2.dir/progress.make +41 -0
  74. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/DependInfo.cmake +62 -0
  75. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/build.make +722 -0
  76. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/cmake_clean.cmake +89 -0
  77. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.make +2 -0
  78. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/compiler_depend.ts +2 -0
  79. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/depend.make +2 -0
  80. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/flags.make +17 -0
  81. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod.dir/progress.make +41 -0
  82. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/DependInfo.cmake +62 -0
  83. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/build.make +722 -0
  84. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/cmake_clean.cmake +89 -0
  85. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.make +2 -0
  86. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/compiler_depend.ts +2 -0
  87. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/depend.make +2 -0
  88. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/flags.make +17 -0
  89. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_dotprod_i8mm.dir/progress.make +41 -0
  90. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/DependInfo.cmake +62 -0
  91. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/build.make +722 -0
  92. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/cmake_clean.cmake +89 -0
  93. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.make +2 -0
  94. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/compiler_depend.ts +2 -0
  95. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/depend.make +2 -0
  96. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/flags.make +17 -0
  97. package/android/src/main/build-arm64/CMakeFiles/rnllama_v8_2_i8mm.dir/progress.make +41 -0
  98. package/android/src/main/build-arm64/Makefile +1862 -0
  99. package/android/src/main/build-arm64/cmake_install.cmake +66 -0
  100. package/android/src/main/java/com/rnllama/LlamaContext.java +91 -17
  101. package/android/src/main/java/com/rnllama/RNLlama.java +37 -4
  102. package/android/src/main/jni-utils.h +6 -0
  103. package/android/src/main/jni.cpp +287 -31
  104. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  105. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  106. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  107. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  108. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  109. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  110. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  111. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  112. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +7 -2
  113. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +7 -2
  114. package/cpp/chat-template.hpp +529 -0
  115. package/cpp/chat.cpp +1085 -0
  116. package/cpp/chat.hpp +55 -0
  117. package/cpp/common.cpp +159 -36
  118. package/cpp/common.h +64 -19
  119. package/cpp/ggml-alloc.c +1 -13
  120. package/cpp/ggml-common.h +0 -2
  121. package/cpp/ggml-cpu-impl.h +6 -12
  122. package/cpp/ggml-cpu-quants.c +937 -340
  123. package/cpp/ggml-cpu.c +207 -113
  124. package/cpp/ggml-cpu.cpp +4 -6
  125. package/cpp/ggml-cpu.h +1 -1
  126. package/cpp/ggml-metal.h +66 -66
  127. package/cpp/ggml-metal.m +141 -23
  128. package/cpp/ggml.c +24 -14
  129. package/cpp/ggml.h +2 -2
  130. package/cpp/json-schema-to-grammar.cpp +46 -66
  131. package/cpp/json-schema-to-grammar.h +15 -1
  132. package/cpp/llama-arch.cpp +7 -2
  133. package/cpp/llama-arch.h +3 -1
  134. package/cpp/llama-chat.cpp +10 -1
  135. package/cpp/llama-chat.h +1 -0
  136. package/cpp/llama-grammar.cpp +86 -6
  137. package/cpp/llama-grammar.h +22 -1
  138. package/cpp/llama-impl.h +6 -6
  139. package/cpp/llama-kv-cache.h +1 -1
  140. package/cpp/llama-mmap.h +1 -0
  141. package/cpp/llama-model-loader.cpp +1 -1
  142. package/cpp/llama-model.cpp +32 -6
  143. package/cpp/llama-sampling.cpp +178 -61
  144. package/cpp/llama-vocab.cpp +8 -3
  145. package/cpp/llama.cpp +188 -128
  146. package/cpp/llama.h +27 -10
  147. package/cpp/log.cpp +32 -10
  148. package/cpp/log.h +12 -1
  149. package/cpp/minja.hpp +2883 -0
  150. package/cpp/rn-llama.cpp +82 -5
  151. package/cpp/rn-llama.h +16 -1
  152. package/cpp/sampling.cpp +68 -41
  153. package/cpp/sampling.h +3 -0
  154. package/cpp/sgemm.cpp +9 -8
  155. package/cpp/unicode.cpp +9 -2
  156. package/ios/CMakeLists.txt +6 -0
  157. package/ios/RNLlama.h +0 -8
  158. package/ios/RNLlama.mm +27 -3
  159. package/ios/RNLlamaContext.h +10 -1
  160. package/ios/RNLlamaContext.mm +269 -57
  161. package/jest/mock.js +21 -2
  162. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  163. package/lib/commonjs/grammar.js +3 -0
  164. package/lib/commonjs/grammar.js.map +1 -1
  165. package/lib/commonjs/index.js +87 -13
  166. package/lib/commonjs/index.js.map +1 -1
  167. package/lib/module/NativeRNLlama.js.map +1 -1
  168. package/lib/module/grammar.js +3 -0
  169. package/lib/module/grammar.js.map +1 -1
  170. package/lib/module/index.js +86 -13
  171. package/lib/module/index.js.map +1 -1
  172. package/lib/typescript/NativeRNLlama.d.ts +107 -2
  173. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  174. package/lib/typescript/grammar.d.ts.map +1 -1
  175. package/lib/typescript/index.d.ts +32 -7
  176. package/lib/typescript/index.d.ts.map +1 -1
  177. package/llama-rn.podspec +1 -1
  178. package/package.json +3 -2
  179. package/src/NativeRNLlama.ts +115 -3
  180. package/src/grammar.ts +3 -0
  181. package/src/index.ts +138 -21
package/cpp/ggml-metal.h CHANGED
@@ -1,66 +1,66 @@
- // Note: this description is outdated
- //
- // An interface allowing to compute lm_ggml_cgraph with Metal
- //
- // This is a fully functional interface that extends ggml with GPU support for Apple devices.
- // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
- //
- // How it works?
- //
- // As long as your program can create and evaluate a lm_ggml_cgraph on the CPU, you can use this
- // interface to evaluate the same graph on the GPU. Instead of using lm_ggml_graph_compute(), you
- // use lm_ggml_metal_graph_compute() (or lm_ggml_vulkan_graph_compute(), etc.)
- //
- // You only need to make sure that all memory buffers that you used during the graph creation
- // are mapped to the device memory with the lm_ggml_metal_add_buffer() function. This mapping is
- // used during the graph evaluation to determine the arguments of the compute kernels.
- //
- // Synchronization between device and host memory (for example for input and output tensors)
- // is done with the lm_ggml_metal_set_tensor() and lm_ggml_metal_get_tensor() functions.
- //
-
- #pragma once
-
- #include "ggml.h"
- #include "ggml-backend.h"
-
- #include <stddef.h>
- #include <stdbool.h>
-
- struct lm_ggml_tensor;
- struct lm_ggml_cgraph;
-
- #ifdef __cplusplus
- extern "C" {
- #endif
-
- //
- // backend API
- // user-code should use only these functions
- //
-
- LM_GGML_BACKEND_API lm_ggml_backend_t lm_ggml_backend_metal_init(void);
-
- LM_GGML_BACKEND_API bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend);
-
- LM_GGML_DEPRECATED(
- LM_GGML_BACKEND_API lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
- "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
-
- LM_GGML_BACKEND_API void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data);
-
- LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void);
-
- // helper to check if the device supports a specific family
- // ideally, the user code should be doing these checks
- // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- LM_GGML_BACKEND_API bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family);
-
- // capture all command buffers committed the next time `lm_ggml_backend_graph_compute` is called
- LM_GGML_BACKEND_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend);
-
- LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void);
-
- #ifdef __cplusplus
- }
- #endif
+ // Note: this description is outdated
+ //
+ // An interface allowing to compute lm_ggml_cgraph with Metal
+ //
+ // This is a fully functional interface that extends ggml with GPU support for Apple devices.
+ // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
+ //
+ // How it works?
+ //
+ // As long as your program can create and evaluate a lm_ggml_cgraph on the CPU, you can use this
+ // interface to evaluate the same graph on the GPU. Instead of using lm_ggml_graph_compute(), you
+ // use lm_ggml_metal_graph_compute() (or lm_ggml_vulkan_graph_compute(), etc.)
+ //
+ // You only need to make sure that all memory buffers that you used during the graph creation
+ // are mapped to the device memory with the lm_ggml_metal_add_buffer() function. This mapping is
+ // used during the graph evaluation to determine the arguments of the compute kernels.
+ //
+ // Synchronization between device and host memory (for example for input and output tensors)
+ // is done with the lm_ggml_metal_set_tensor() and lm_ggml_metal_get_tensor() functions.
+ //
+
+ #pragma once
+
+ #include "ggml.h"
+ #include "ggml-backend.h"
+
+ #include <stddef.h>
+ #include <stdbool.h>
+
+ struct lm_ggml_tensor;
+ struct lm_ggml_cgraph;
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ //
+ // backend API
+ // user-code should use only these functions
+ //
+
+ LM_GGML_BACKEND_API lm_ggml_backend_t lm_ggml_backend_metal_init(void);
+
+ LM_GGML_BACKEND_API bool lm_ggml_backend_is_metal(lm_ggml_backend_t backend);
+
+ LM_GGML_DEPRECATED(
+ LM_GGML_BACKEND_API lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
+ "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
+
+ LM_GGML_BACKEND_API void lm_ggml_backend_metal_set_abort_callback(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * user_data);
+
+ LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void);
+
+ // helper to check if the device supports a specific family
+ // ideally, the user code should be doing these checks
+ // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ LM_GGML_BACKEND_API bool lm_ggml_backend_metal_supports_family(lm_ggml_backend_t backend, int family);
+
+ // capture all command buffers committed the next time `lm_ggml_backend_graph_compute` is called
+ LM_GGML_BACKEND_API void lm_ggml_backend_metal_capture_next_compute(lm_ggml_backend_t backend);
+
+ LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void);
+
+ #ifdef __cplusplus
+ }
+ #endif
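
The effective change in this header is the PR link in the deprecation message (github.com/ggerganov → github.com/ggml-org); the declarations themselves are unchanged. For orientation, the sketch below shows how application code might exercise the backend API declared above. It is illustrative only: `lm_ggml_backend_free()` is assumed to come from `ggml-backend.h`, and the family id 1007 is simply the numeric value of `MTLGPUFamilyApple7`, used here as an example.

```c
#include <stdio.h>

#include "ggml-metal.h"
#include "ggml-backend.h"

int main(void) {
    // bring up the Metal backend declared in ggml-metal.h
    lm_ggml_backend_t backend = lm_ggml_backend_metal_init();
    if (backend == NULL) {
        fprintf(stderr, "Metal backend not available\n");
        return 1;
    }

    if (lm_ggml_backend_is_metal(backend)) {
        // 1007 is the numeric value of MTLGPUFamilyApple7 (illustrative only);
        // see the Apple feature-set tables referenced in the header
        const bool apple7 = lm_ggml_backend_metal_supports_family(backend, 1007);
        printf("MTLGPUFamilyApple7 supported: %s\n", apple7 ? "yes" : "no");
    }

    // assumed helper from ggml-backend.h
    lm_ggml_backend_free(backend);
    return 0;
}
```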
package/cpp/ggml-metal.m CHANGED
@@ -19,7 +19,17 @@
  // max number of MTLCommandBuffer used to submit a graph for processing
  #define LM_GGML_METAL_MAX_COMMAND_BUFFERS 8

- #define UNUSED(x) (void)(x)
+ #ifndef TARGET_OS_VISION
+ #define TARGET_OS_VISION 0
+ #endif
+
+ // create residency sets only on macOS >= 15.0
+ #if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
+ TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
+ TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
+ TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
+ #define LM_GGML_METAL_HAS_RESIDENCY_SETS 1
+ #endif

  // globals

@@ -39,6 +49,7 @@ static struct lm_ggml_backend_metal_device_context {

  bool has_simdgroup_reduction;
  bool has_simdgroup_mm;
+ bool has_residency_sets;
  bool has_bfloat;
  bool use_bfloat;

@@ -48,6 +59,7 @@ static struct lm_ggml_backend_metal_device_context {
  /*.mtl_device_ref_count =*/ 0,
  /*.has_simdgroup_reduction =*/ false,
  /*.has_simdgroup_mm =*/ false,
+ /*.has_residency_sets =*/ false,
  /*.has_bfloat =*/ false,
  /*.use_bfloat =*/ false,
  /*.name =*/ "",
@@ -59,12 +71,18 @@ static id<MTLDevice> lm_ggml_backend_metal_device_acq(struct lm_ggml_backend_met

  if (ctx->mtl_device == nil) {
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
+ }

+ if (ctx->mtl_device) {
  ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
  ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];

  ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];

+ #if defined(LM_GGML_METAL_HAS_RESIDENCY_SETS)
+ ctx->has_residency_sets = getenv("LM_GGML_METAL_NO_RESIDENCY") == NULL;
+ #endif
+
  ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
  ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];

@@ -90,8 +108,10 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device
  ctx->mtl_device_ref_count--;

  if (ctx->mtl_device_ref_count == 0) {
- [ctx->mtl_device release];
- ctx->mtl_device = nil;
+ if (ctx->mtl_device) {
+ [ctx->mtl_device release];
+ ctx->mtl_device = nil;
+ }
  }
  }

@@ -483,6 +503,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

  ctx->queue = [device newCommandQueue];
+ if (ctx->queue == nil) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
+ return NULL;
+ }
+
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

  id<MTLLibrary> metal_library;
@@ -509,7 +534,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  const bool try_metallib = true;
  #endif

+ #if TARGET_OS_SIMULATOR
+ NSString * path_lib = [bundle pathForResource:@"ggml-llama-sim" ofType:@"metallib"];
+ #else
  NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
+ #endif
  if (path_lib == nil) {
  // Try to find the resource in the directory where the current binary located.
  NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
@@ -649,6 +678,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend

  LM_GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
  LM_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
+ LM_GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
  LM_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
  LM_GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
  LM_GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
@@ -1035,8 +1065,70 @@ struct lm_ggml_backend_metal_buffer_context {
  // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
  int n_buffers;
  struct lm_ggml_backend_metal_buffer buffers[LM_GGML_METAL_MAX_BUFFERS];
+
+ // optional MTLResidencySet
+ id rset;
  };

+ // rset init
+ static bool lm_ggml_backend_metal_buffer_rset_init(
+ struct lm_ggml_backend_metal_buffer_context * ctx,
+ struct lm_ggml_backend_metal_device_context * ctx_dev,
+ id<MTLDevice> device) {
+ ctx->rset = nil;
+
+ if (!ctx_dev->has_residency_sets) {
+ return true;
+ }
+
+ #if defined(LM_GGML_METAL_HAS_RESIDENCY_SETS)
+ if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+ MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
+ desc.label = @"lm_ggml_backend_metal";
+ desc.initialCapacity = ctx->n_buffers;
+
+ NSError * error;
+ ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];
+ if (error) {
+ LM_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ [desc release];
+ return false;
+ }
+
+ [desc release];
+
+ for (int i = 0; i < ctx->n_buffers; i++) {
+ [ctx->rset addAllocation:ctx->buffers[i].metal];
+ }
+
+ [ctx->rset commit];
+ [ctx->rset requestResidency];
+
+ return true;
+ }
+ #else
+ LM_GGML_UNUSED(ctx_dev);
+ LM_GGML_UNUSED(device);
+ #endif
+
+ return true;
+ }
+
+ // rset free
+ static void lm_ggml_backend_metal_buffer_rset_free(struct lm_ggml_backend_metal_buffer_context * ctx) {
+ #if defined(LM_GGML_METAL_HAS_RESIDENCY_SETS)
+ if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+ if (ctx->rset) {
+ [ctx->rset endResidency];
+ [ctx->rset removeAllAllocations];
+ [ctx->rset release];
+ }
+ }
+ #else
+ LM_GGML_UNUSED(ctx);
+ #endif
+ }
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -1120,12 +1212,13 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_OP_SUM_ROWS:
  case LM_GGML_OP_SOFT_MAX:
  case LM_GGML_OP_GROUP_NORM:
- return has_simdgroup_reduction;
+ return has_simdgroup_reduction && lm_ggml_is_contiguous(op->src[0]);
  case LM_GGML_OP_RMS_NORM:
- return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
  case LM_GGML_OP_ARGMAX:
- case LM_GGML_OP_NORM:
  return true;
+ case LM_GGML_OP_NORM:
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && lm_ggml_is_contiguous_1(op->src[0]));
  case LM_GGML_OP_ROPE:
  {
  const int mode = ((const int32_t *) op->op_params)[2];
@@ -1894,7 +1987,7 @@ static void lm_ggml_metal_encode_node(
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

  // TODO: add lm_ggml_metal_kargs struct
- // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+ // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
  [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  if (id_src1) {
@@ -4176,6 +4269,8 @@ static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t bu
  for (int i = 0; i < ctx->n_buffers; i++) {
  [ctx->buffers[i].metal release];
  }
+
+ lm_ggml_backend_metal_buffer_rset_free(ctx);
  lm_ggml_backend_metal_device_rel(buffer->buft->device->context);

  if (ctx->owned) {
@@ -4198,19 +4293,19 @@ static void * lm_ggml_backend_metal_buffer_get_base(lm_ggml_backend_buffer_t buf
  static void lm_ggml_backend_metal_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
  memset((char *)tensor->data + offset, value, size);

- UNUSED(buffer);
+ LM_GGML_UNUSED(buffer);
  }

  static void lm_ggml_backend_metal_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
  memcpy((char *)tensor->data + offset, data, size);

- UNUSED(buffer);
+ LM_GGML_UNUSED(buffer);
  }

  static void lm_ggml_backend_metal_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
  memcpy(data, (const char *)tensor->data + offset, size);

- UNUSED(buffer);
+ LM_GGML_UNUSED(buffer);
  }

  static bool lm_ggml_backend_metal_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) {
@@ -4220,7 +4315,7 @@ static bool lm_ggml_backend_metal_buffer_cpy_tensor(lm_ggml_backend_buffer_t buf
  }
  return false;

- UNUSED(buffer);
+ LM_GGML_UNUSED(buffer);
  }

  static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) {
@@ -4246,7 +4341,7 @@ static struct lm_ggml_backend_buffer_i lm_ggml_backend_metal_buffer_i = {
  static const char * lm_ggml_backend_metal_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
  return "Metal";

- UNUSED(buft);
+ LM_GGML_UNUSED(buft);
  }

  static void lm_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
@@ -4270,8 +4365,8 @@ static void lm_ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_
  }
  #endif
  #endif
- UNUSED(device);
- UNUSED(size_aligned);
+ LM_GGML_UNUSED(device);
+ LM_GGML_UNUSED(size_aligned);
  }

  static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) {
@@ -4284,7 +4379,8 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(l
  size_aligned += (size_page - (size_aligned % size_page));
  }

- id<MTLDevice> device = lm_ggml_backend_metal_device_acq(buft->device->context);
+ struct lm_ggml_backend_metal_device_context * ctx_dev = (struct lm_ggml_backend_metal_device_context *)buft->device->context;
+ id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);

  ctx->all_data = lm_ggml_metal_host_malloc(size_aligned);
  ctx->all_size = size_aligned;
@@ -4307,7 +4403,14 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(l
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
  LM_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
  free(ctx);
- lm_ggml_backend_metal_device_rel(buft->device->context);
+ lm_ggml_backend_metal_device_rel(ctx_dev);
+ return NULL;
+ }
+
+ if (!lm_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+ LM_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+ free(ctx);
+ lm_ggml_backend_metal_device_rel(ctx_dev);
  return NULL;
  }

@@ -4318,7 +4421,7 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_type_alloc_buffer(l

  static size_t lm_ggml_backend_metal_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) {
  return 32;
- UNUSED(buft);
+ LM_GGML_UNUSED(buft);
  }

  static size_t lm_ggml_backend_metal_buffer_type_get_max_size(lm_ggml_backend_buffer_type_t buft) {
@@ -4328,13 +4431,13 @@ static size_t lm_ggml_backend_metal_buffer_type_get_max_size(lm_ggml_backend_buf

  return max_size;

- UNUSED(buft);
+ LM_GGML_UNUSED(buft);
  }

  static bool lm_ggml_backend_metal_buffer_type_is_host(lm_ggml_backend_buffer_type_t buft) {
  return true;

- UNUSED(buft);
+ LM_GGML_UNUSED(buft);
  }

  lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void) {
@@ -4357,7 +4460,7 @@ lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void) {
  static const char * lm_ggml_backend_metal_buffer_from_ptr_type_get_name(lm_ggml_backend_buffer_type_t buft) {
  return "Metal_Mapped";

- UNUSED(buft);
+ LM_GGML_UNUSED(buft);
  }

  static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_from_ptr_type(void) {
@@ -4400,7 +4503,8 @@ lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size
  size_aligned += (size_page - (size_aligned % size_page));
  }

- id<MTLDevice> device = lm_ggml_backend_metal_device_acq(&g_lm_ggml_ctx_dev_main);
+ struct lm_ggml_backend_metal_device_context * ctx_dev = &g_lm_ggml_ctx_dev_main;
+ id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);

  // the buffer fits into the max buffer size allowed by the device
  if (size_aligned <= device.maxBufferLength) {
@@ -4453,6 +4557,13 @@ lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size
  }
  }

+ if (!lm_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+ LM_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+ free(ctx);
+ lm_ggml_backend_metal_device_rel(ctx_dev);
+ return NULL;
+ }
+
  return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
  }

@@ -4461,7 +4572,7 @@ lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size
  static const char * lm_ggml_backend_metal_name(lm_ggml_backend_t backend) {
  return "Metal";

- UNUSED(backend);
+ LM_GGML_UNUSED(backend);
  }

  static void lm_ggml_backend_metal_free(lm_ggml_backend_t backend) {
@@ -4766,6 +4877,13 @@ static lm_ggml_backend_buffer_t lm_ggml_backend_metal_device_buffer_from_ptr(lm_
  }
  }

+ if (!lm_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+ LM_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+ free(ctx);
+ lm_ggml_backend_metal_device_rel(ctx_dev);
+ return NULL;
+ }
+
  return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
  }

@@ -4779,7 +4897,7 @@ static bool lm_ggml_backend_metal_device_supports_buft(lm_ggml_backend_dev_t dev
  return buft->iface.get_name == lm_ggml_backend_metal_buffer_type_get_name ||
  buft->iface.get_name == lm_ggml_backend_metal_buffer_from_ptr_type_get_name;

- UNUSED(dev);
+ LM_GGML_UNUSED(dev);
  }

  static bool lm_ggml_backend_metal_device_offload_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) {
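
The residency-set hunks above gate the feature twice: at compile time via LM_GGML_METAL_HAS_RESIDENCY_SETS, and at runtime via the LM_GGML_METAL_NO_RESIDENCY environment variable checked in lm_ggml_backend_metal_device_acq(). A minimal, hedged sketch of opting out at runtime follows; the variable name is taken from the diff, setenv is standard POSIX, and the helper name here is hypothetical.

```c
#include <stdlib.h>

#include "ggml-metal.h"

// Hypothetical helper: disable Metal residency sets before the device
// context is acquired. The env var is read once during device acquisition,
// so it must be set before the first backend or buffer initialization.
lm_ggml_backend_t init_metal_without_residency_sets(void) {
    setenv("LM_GGML_METAL_NO_RESIDENCY", "1", 1 /* overwrite */);
    return lm_ggml_backend_metal_init();
}
```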
package/cpp/ggml.c CHANGED
@@ -128,6 +128,10 @@ static void lm_ggml_print_backtrace_symbols(void) {
  #endif

  static void lm_ggml_print_backtrace(void) {
+ const char * LM_GGML_NO_BACKTRACE = getenv("LM_GGML_NO_BACKTRACE");
+ if (LM_GGML_NO_BACKTRACE) {
+ return;
+ }
  char attach[32];
  snprintf(attach, sizeof(attach), "attach %d", getpid());
  int pid = fork();
@@ -1388,7 +1392,7 @@ bool lm_ggml_are_same_stride(const struct lm_ggml_tensor * t0, const struct lm_g
  (t0->nb[3] == t1->nb[3]);
  }

- // check if t1 can be represented as a repeatition of t0
+ // check if t1 can be represented as a repetition of t0
  bool lm_ggml_can_repeat(const struct lm_ggml_tensor * t0, const struct lm_ggml_tensor * t1) {
  static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");

@@ -5352,7 +5356,7 @@ static void lm_ggml_compute_backward(
  } break;
  case LM_GGML_OP_MUL: {
  if (src0_needs_grads) {
- lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, src1, grad));
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, lm_ggml_mul(ctx, grad, src1));
  }
  if (src1_needs_grads) {
  struct lm_ggml_tensor * tmp = lm_ggml_mul(ctx, src0, grad);
@@ -5444,21 +5448,25 @@ static void lm_ggml_compute_backward(
  // src1.shape [n,p,qq,rr]

  if (src0_needs_grads) {
- struct lm_ggml_tensor * s1_tg =
+ LM_GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+ LM_GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+ struct lm_ggml_tensor * tmp =
  lm_ggml_out_prod(ctx, // [n,m,qq,rr]
  src1, // [n,p,qq,rr]
  grad); // [m,p,qq,rr]
- const int64_t qq = s1_tg->ne[2];
- const int64_t rr = s1_tg->ne[3];
- const int64_t q1 = src0->ne[2];
- const int64_t r1 = src0->ne[3];
- const bool ne2_broadcasted = qq > q1;
- const bool ne3_broadcasted = rr > r1;
- if (ne2_broadcasted || ne3_broadcasted) {
- // sum broadcast repetitions of s1_tg into shape of src0
- s1_tg = lm_ggml_repeat_back(ctx, s1_tg, src0);
+ if (!lm_ggml_are_same_shape(tmp, src0)) {
+ LM_GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+ LM_GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+ LM_GGML_ASSERT(tmp->ne[3] == 1);
+
+ const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+ const size_t nb2 = tmp->nb[2] * nr2;
+ const size_t nb3 = tmp->nb[2];
+
+ tmp = lm_ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+ tmp = lm_ggml_repeat_back(ctx, tmp, src0);
  }
- lm_ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+ lm_ggml_add_or_set(ctx, cgraph, isrc0, tmp);
  }
  if (src1_needs_grads) {
  lm_ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5527,7 +5535,9 @@ static void lm_ggml_compute_backward(
  if (src0_needs_grads) {
  LM_GGML_ASSERT(!cgraph->grads[isrc0] || lm_ggml_is_contiguous(cgraph->grads[isrc0]));
  LM_GGML_ASSERT(lm_ggml_is_contiguous(grad));
- lm_ggml_add_or_set(ctx, cgraph, isrc0, grad);
+ LM_GGML_ASSERT(lm_ggml_nelements(tensor) == lm_ggml_nelements(src0));
+ lm_ggml_add_or_set(ctx, cgraph, isrc0,
+ lm_ggml_are_same_shape(tensor, src0) ? grad : lm_ggml_reshape(ctx, grad, src0));
  }
  } break;
  case LM_GGML_OP_RESHAPE: {
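
One note on the LM_GGML_OP_MUL hunk above: the operand order in the src0 gradient is swapped so that the output-shaped gradient comes first. The sketch below is a hedged illustration of the rule being implemented, assuming lm_ggml_mul's usual convention that the second operand is the one broadcast over the first.

```c
#include "ggml.h"

// For out = src0 * src1 (src1 possibly broadcast), d(out)/d(src0) = grad * src1.
// grad has the output's (larger) shape, so it is passed as the first operand of
// lm_ggml_mul; with the old order a smaller, broadcast src1 in the first slot
// could not accommodate the full-size grad.
static struct lm_ggml_tensor * mul_grad_src0(
        struct lm_ggml_context * ctx,
        struct lm_ggml_tensor  * grad,
        struct lm_ggml_tensor  * src1) {
    return lm_ggml_mul(ctx, grad, src1);
}
```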
package/cpp/ggml.h CHANGED
@@ -198,7 +198,7 @@

  #ifndef __GNUC__
  # define LM_GGML_ATTRIBUTE_FORMAT(...)
- #elif defined(__MINGW32__)
+ #elif defined(__MINGW32__) && !defined(__clang__)
  # define LM_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
  #else
  # define LM_GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
@@ -1776,7 +1776,7 @@ extern "C" {
  struct lm_ggml_tensor * a,
  int k);

- #define LM_GGML_KQ_MASK_PAD 32
+ #define LM_GGML_KQ_MASK_PAD 64

  // q: [n_embd, n_batch, n_head, 1]
  // k: [n_embd, n_kv, n_head_kv, 1]
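
The second hunk doubles the required padding of the flash-attention KQ mask. In upstream llama.cpp this constant is used to round the mask's batch dimension up before the mask is handed to the attention kernels; the sketch below is a minimal, hedged illustration of that rounding, assuming the usual LM_GGML_PAD(x, n) rounding macro from ggml.h.

```c
#include <stdint.h>

#include "ggml.h"

// Round a mask dimension up to the padding the attention kernels expect.
// With this release the result becomes a multiple of 64 instead of 32.
static int64_t kq_mask_padded(int64_t n) {
    return LM_GGML_PAD(n, LM_GGML_KQ_MASK_PAD);  // assumes the standard rounding macro
}
```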