@novastera-oss/llamarn 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/RNLlamaCpp.podspec +3 -2
  2. package/android/CMakeLists.txt +6 -3
  3. package/android/src/main/cpp/include/llama.h +140 -38
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  12. package/cpp/LlamaCppModel.cpp +48 -67
  13. package/cpp/LlamaCppModel.h +8 -3
  14. package/cpp/PureCppImpl.cpp +1 -1
  15. package/cpp/PureCppImpl.h +2 -2
  16. package/cpp/build-info.cpp +2 -2
  17. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  18. package/cpp/llama.cpp/Makefile +2 -2
  19. package/cpp/llama.cpp/README.md +33 -13
  20. package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
  21. package/cpp/llama.cpp/common/arg.cpp +38 -12
  22. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  23. package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
  24. package/cpp/llama.cpp/common/chat-parser.h +4 -1
  25. package/cpp/llama.cpp/common/chat.cpp +16 -13
  26. package/cpp/llama.cpp/common/chat.h +1 -1
  27. package/cpp/llama.cpp/common/common.cpp +52 -40
  28. package/cpp/llama.cpp/common/common.h +5 -2
  29. package/cpp/llama.cpp/common/json-partial.cpp +5 -4
  30. package/cpp/llama.cpp/common/json-partial.h +2 -1
  31. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
  32. package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
  33. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  34. package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  37. package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
  38. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
  39. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
  41. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  79. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  82. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  83. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  88. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  107. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  108. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
  109. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  110. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  112. package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
  113. package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
  114. package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
  115. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  116. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  117. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  118. package/cpp/llama.cpp/include/llama.h +140 -38
  119. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  120. package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
  121. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  122. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  123. package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
  124. package/cpp/llama.cpp/src/llama-batch.h +47 -17
  125. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  126. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  127. package/cpp/llama.cpp/src/llama-context.cpp +488 -313
  128. package/cpp/llama.cpp/src/llama-context.h +38 -17
  129. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  130. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  131. package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
  132. package/cpp/llama.cpp/src/llama-graph.h +109 -52
  133. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  134. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  135. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
  136. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
  137. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
  138. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
  139. package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
  140. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  141. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  142. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
  143. package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
  144. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  145. package/cpp/llama.cpp/src/llama-memory.h +89 -4
  146. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  147. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  148. package/cpp/llama.cpp/src/llama-model.cpp +735 -143
  149. package/cpp/llama.cpp/src/llama-model.h +4 -0
  150. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  151. package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
  152. package/cpp/llama.cpp/src/llama.cpp +11 -7
  153. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  154. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
  155. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
  156. package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
  157. package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
  158. package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
  159. package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
  160. package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
  161. package/cpp/rn-completion.cpp +65 -10
  162. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  163. package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
  164. package/ios/include/chat.h +1 -1
  165. package/ios/include/common/minja/chat-template.hpp +1 -1
  166. package/ios/include/common/minja/minja.hpp +1 -1
  167. package/ios/include/common.h +5 -2
  168. package/ios/include/json-schema-to-grammar.h +4 -4
  169. package/ios/include/llama.h +140 -38
  170. package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
  171. package/ios/libs/llama.xcframework/Info.plist +20 -20
  172. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  173. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
  174. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
  175. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
  176. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  177. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  178. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  179. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
  180. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  181. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  182. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  184. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
  185. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
  186. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
  187. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
  188. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
  189. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
  190. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  191. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
  192. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
  193. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  194. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  195. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  196. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
  197. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
  198. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
  199. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  200. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  201. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
  202. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
  203. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  204. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  205. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  206. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  207. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
  208. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
  209. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
  210. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  211. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  212. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
  213. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
  214. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
  215. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
  216. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  217. package/package.json +1 -2
  218. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  219. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  220. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  221. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
  222. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
  223. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  224. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  225. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -0,0 +1,98 @@
+ #pragma once
+
+ #define GGML_COMMON_DECL_CPP
+ #include "ggml-common.h"
+
+ #include "traits.h"
+ #include "ggml.h"
+
+ // GGML internal header
+
+ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);
+
+ template <int K> constexpr int QK_0() {
+     if constexpr (K == 4) {
+         return QK4_0;
+     }
+     if constexpr (K == 8) {
+         return QK8_0;
+     }
+     return -1;
+ }
+
+ template <int K, int N> struct block {
+     ggml_half d[N]; // deltas for N qK_0 blocks
+     int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+ };
+
+ // control size
+ static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+ static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+ static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+ static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+ using block_q4_0x4 = block<4, 4>;
+ using block_q4_0x8 = block<4, 8>;
+ using block_q8_0x4 = block<8, 4>;
+ using block_q8_0x8 = block<8, 8>;
+
+ struct block_q4_Kx8 {
+     ggml_half d[8]; // super-block scale for quantized scales
+     ggml_half dmin[8]; // super-block scale for quantized mins
+     uint8_t scales[96]; // scales and mins, quantized with 6 bits
+     uint8_t qs[1024]; // 4--bit quants
+ };
+
+ static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+
+ struct block_q8_Kx4 {
+     float d[4]; // delta
+     int8_t qs[QK_K * 4]; // quants
+     int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
+ };
+
+ static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
+
+ struct block_iq4_nlx4 {
+     ggml_half d[4]; // deltas for 4 iq4_nl blocks
+     uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+ };
+
+ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
+ #if defined(__cplusplus)
+ extern "C" {
+ #endif
+
+ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+ // Native implementations
+ void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+ #if defined(__cplusplus)
+ } // extern "C"
+ #endif
@@ -944,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
      for (int i = 0; i < offset; ++i) { \
          x[i] = vec_add(x[i], x[offset + i]); \
      } \
-     res = vec_extract(x[0], 0) + \
-           vec_extract(x[0], 1) + \
-           vec_extract(x[0], 2) + \
-           vec_extract(x[0], 3); \
+     float32x4_t tmp = x[0] + vec_reve(x[0]); \
+     res = tmp[0] + tmp[1]; \
  }

  #define GGML_F32_VEC GGML_F32x4
@@ -1,4 +1,4 @@
- #include "ggml-cpu-traits.h"
+ #include "traits.h"

  #include "ggml-backend-impl.h"
  #include "ggml-backend.h"
@@ -207,9 +207,9 @@ typedef float2 dfloat2;
  #define FP16_MMA_AVAILABLE
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

- #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+ #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
  #define FP16_MMA_AVAILABLE
- #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
+ #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
  #define NEW_MMA_AVAILABLE
@@ -262,11 +262,11 @@ static bool cp_async_available(const int cc) {
  }

  static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-     return __AMDGCN_WAVEFRONT_SIZE;
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
+     return 64;
  #else
      return 32;
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
  }

  [[noreturn]]
@@ -466,9 +466,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
  }

- // TODO: move to ggml-common.h
- static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
-
  typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

  static __device__ __forceinline__ float get_alibi_slope(
@@ -635,6 +632,7 @@ struct ggml_cuda_device_info {
      int nsm;       // number of streaming multiprocessors
      size_t smpb;   // max. shared memory per block
      size_t smpbo;  // max. shared memory per block (with opt-in)
+     bool integrated; // Device is integrated as opposed to discrete
      bool vmm;      // virtual memory support
      size_t vmm_granularity; // granularity of virtual memory
      size_t total_vram;
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
      float KQ_max_scale[cols_per_thread];
  #pragma unroll
      for (int col = 0; col < cols_per_thread; ++col) {
-         KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+         const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+         KQ_max_scale[col] = expf(KQ_max_diff);
          KQ_max[col] = KQ_max_new[col];

+         *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
          // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
          KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
      }
@@ -1246,7 +1249,7 @@ static __global__ void flash_attn_ext_f16(
          NO_DEVICE_CODE;
          return;
      }
- #endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING

      static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");

@@ -243,10 +243,10 @@ static ggml_cuda_device_info ggml_cuda_init() {

          info.default_tensor_split[id] = total_vram;
          total_vram += prop.totalGlobalMem;
-
-         info.devices[id].nsm = prop.multiProcessorCount;
-         info.devices[id].smpb = prop.sharedMemPerBlock;
-         info.devices[id].warp_size = prop.warpSize;
+         info.devices[id].integrated = prop.integrated;
+         info.devices[id].nsm = prop.multiProcessorCount;
+         info.devices[id].smpb = prop.sharedMemPerBlock;
+         info.devices[id].warp_size = prop.warpSize;
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
          info.devices[id].smpbo = prop.sharedMemPerBlock;

@@ -615,9 +615,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
      ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

      ggml_cuda_set_device(ctx->device);
-     CUDA_CHECK(cudaDeviceSynchronize());
-     CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
-     CUDA_CHECK(cudaDeviceSynchronize());
+     CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
+     CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
  }

  static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
@@ -1065,6 +1064,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
      GGML_UNUSED(buft);
  }

+ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
+     return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+ }
+
  static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      CUDA_CHECK(cudaFreeHost(buffer->context));
  }
@@ -1140,7 +1143,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
  static cudaError_t ggml_cuda_cpy_tensor_2d(
      void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {

-     GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
      const char * src_ptr = (const char *) src->data;
      char * dst_ptr = (char *) dst;

@@ -1423,8 +1425,6 @@ static void ggml_cuda_op_mul_mat(
      const int64_t nb2 = dst->nb[2];
      const int64_t nb3 = dst->nb[3];

-     GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
-     GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
      ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
      ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;

@@ -1746,7 +1746,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
      GGML_ASSERT(!ggml_is_transposed(src0));
      GGML_ASSERT(!ggml_is_transposed(src1));

-     GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
      GGML_ASSERT(src0->type == GGML_TYPE_F16);

      // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
@@ -2641,6 +2641,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {

  static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
      bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+     // flag used to determine whether it is an integrated_gpu
+     const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;

      while (!graph_evaluated_or_captured) {
          // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2659,10 +2661,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
          if (node->src[j] != nullptr) {
              assert(node->src[j]->buffer);
              assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
-                    ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
+                    ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
          }
      }
- #endif
+ #else
+     GGML_UNUSED(integrated);
+ #endif // NDEBUG

      bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
      if (!ok) {
@@ -2994,9 +2998,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
      {
          struct ggml_tensor * a = op->src[0];
          struct ggml_tensor * b = op->src[1];
-         // for small weight matrices the active device can end up without any rows, don't use row split in those cases
-         // this avoids some edge cases (and the performance would not be good anyways)
          if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {
+             if (a->ne[2] > 1 || a->ne[3] > 1) {
+                 return false;
+             }
+             // for small weight matrices the active device can end up without any rows, don't use row split in those cases
+             // this avoids some edge cases (and the performance would not be good anyways)
              ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;
              int64_t row_low;
              int64_t row_high;
@@ -3263,7 +3270,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
  }

  static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
-     return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
+     ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
+     const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;
+     return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));
  }

  static int64_t get_op_batch_size(const ggml_tensor * op) {
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
      float * __restrict__ dst, const int64_t L) {
      GGML_UNUSED(src1_nb0);
      GGML_UNUSED(src2_nb0);
+
+     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
      const int bidx = blockIdx.x; // split along B
      const int bidy = blockIdx.y; // split along D
      const int tid = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
      if (N == 16) {
  #pragma unroll
          for (size_t i = 0; i < splitD / 4; i += 2) {
-             float value = A_block[(wid * warpSize + i) * stride_A + wtid];
+             float value = A_block[(wid * warp_size + i) * stride_A + wtid];
              // todo: bank conflict
              // I am always confused with how to use the swizzling method to solve
              // bank conflit. Hoping somebody can tell me.
-             smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+             smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
          }
  #pragma unroll
          for (size_t i = 0; i < splitD / 4; i += 2) {
-             float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
-             smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
+             float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
+             smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
          }
      }

@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
      add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
  endif()

+ if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
+     add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
+ endif()
+
  if (NOT GGML_CUDA_FA)
      add_compile_definitions(GGML_CUDA_NO_FA)
  endif()
@@ -32,6 +32,8 @@
  extern "C" {
  #endif

+ void ggml_print_backtrace(void);
+
  #ifndef MIN
  # define MIN(a, b) ((a) < (b) ? (a) : (b))
  #endif
@@ -44,21 +44,22 @@ if (GGML_METAL_EMBED_LIBRARY)
      set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")

      add_custom_command(
-         OUTPUT ${METALLIB_EMBED_ASM}
+         OUTPUT "${METALLIB_EMBED_ASM}"
          COMMAND echo "Embedding Metal library"
-         COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP}
-         COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED}
-         COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM}
-         COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM}
-         COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM}
-         COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
-         COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM}
-         COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM}
+         COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
+         COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
+         COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
+         COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
+         COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
+         COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
+         COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
+         COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
          DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
          COMMENT "Generate assembly for embedded Metal library"
+         VERBATIM
      )

-     target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM})
+     target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
  else()
      if (GGML_METAL_SHADER_DEBUG)
          # custom command to do the following:
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_COS,
      GGML_METAL_KERNEL_TYPE_NEG,
      GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+     GGML_METAL_KERNEL_TYPE_MEAN,
      GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
      GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
      GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
      GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
      case GGML_OP_LOG:
          return false; // TODO: implement
      case GGML_OP_SUM_ROWS:
+     case GGML_OP_MEAN:
      case GGML_OP_SOFT_MAX:
      case GGML_OP_GROUP_NORM:
          return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
          [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
      } break;
  case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
      {
          GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

-         id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+         id<MTLComputePipelineState> pipeline = nil;
+
+         switch (dst->op) {
+             case GGML_OP_SUM_ROWS:
+                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                 break;
+             case GGML_OP_MEAN:
+                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                 break;
+             default:
+                 GGML_ABORT("fatal error");
+         }
+
+         int nth = 32; // SIMD width
+
+         while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+             nth *= 2;
+         }

+         nth = MIN(nth, ne00);

          ggml_metal_kargs_sum_rows args = {
              /*.ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
          };

          [encoder setComputePipelineState:pipeline];
-         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-         [encoder setBytes:&args length:sizeof(args) atIndex:2];
+         [encoder setBytes:&args length:sizeof(args) atIndex:0];
+         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+         [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
      } break;
  case GGML_OP_SOFT_MAX:
      {
@@ -4766,6 +4789,8 @@ static bool ggml_metal_encode_node(
          GGML_ASSERT(nqptg % 8 == 0);
          GGML_ASSERT(ncpsg % 32 == 0);

+         const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
          // 2*(2*ncpsg + nqptg)*(nsg)
          // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
          //
@@ -4773,7 +4798,7 @@
          // the shared memory needed for the simdgroups to load the KV cache
          // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
          //
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))

          int64_t nsgmax = 2;

@@ -4810,9 +4835,9 @@
          // and store the soft_max values and the mask
          //
          // ne00*(nsg)
-         // each simdgroup has a full f16 head vector in shared mem to accumulate results
+         // each simdgroup has a full f32 head vector in shared mem to accumulate results
          //
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))

          int64_t nsgmax = 2;
          while (true) {