@novastera-oss/llamarn 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. package/android/src/main/cpp/include/llama.h +134 -36
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +2 -2
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +30 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +50 -40
  26. package/cpp/llama.cpp/common/common.h +5 -2
  27. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  28. package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
  29. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
  30. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  31. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
  32. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  33. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  34. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  35. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  70. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  72. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
  74. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
  76. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  84. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  96. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
  97. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
  98. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
  99. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  101. package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
  102. package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
  103. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
  104. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
  105. package/cpp/llama.cpp/include/llama.h +134 -36
  106. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  107. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  108. package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
  109. package/cpp/llama.cpp/src/llama-arch.h +7 -1
  110. package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
  111. package/cpp/llama.cpp/src/llama-batch.h +36 -11
  112. package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
  113. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  114. package/cpp/llama.cpp/src/llama-context.cpp +313 -213
  115. package/cpp/llama.cpp/src/llama-context.h +16 -12
  116. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  117. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  118. package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
  119. package/cpp/llama.cpp/src/llama-graph.h +90 -34
  120. package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
  121. package/cpp/llama.cpp/src/llama-hparams.h +8 -2
  122. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
  123. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  124. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
  125. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
  126. package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
  127. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
  128. package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
  129. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
  130. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
  131. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  132. package/cpp/llama.cpp/src/llama-memory.h +64 -23
  133. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  134. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  135. package/cpp/llama.cpp/src/llama-model.cpp +726 -141
  136. package/cpp/llama.cpp/src/llama-model.h +4 -0
  137. package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
  138. package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
  139. package/cpp/llama.cpp/src/llama.cpp +11 -7
  140. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  141. package/cpp/rn-completion.cpp +2 -2
  142. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  143. package/ios/include/chat.h +1 -1
  144. package/ios/include/common.h +5 -2
  145. package/ios/include/llama.h +134 -36
  146. package/ios/libs/llama.xcframework/Info.plist +18 -18
  147. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  148. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  149. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
  150. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  151. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  152. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  153. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  154. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  155. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
  160. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
  161. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  162. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
  163. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  164. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  165. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
  167. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
  168. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  173. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  175. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  178. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  183. package/package.json +1 -2
  184. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  185. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  186. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  187. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  188. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  189. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  190. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  191. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  192. /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
@@ -0,0 +1,98 @@
1
+ #pragma once
2
+
3
+ #define GGML_COMMON_DECL_CPP
4
+ #include "ggml-common.h"
5
+
6
+ #include "traits.h"
7
+ #include "ggml.h"
8
+
9
+ // GGML internal header
10
+
11
+ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);
12
+
13
+ template <int K> constexpr int QK_0() {
14
+ if constexpr (K == 4) {
15
+ return QK4_0;
16
+ }
17
+ if constexpr (K == 8) {
18
+ return QK8_0;
19
+ }
20
+ return -1;
21
+ }
22
+
23
+ template <int K, int N> struct block {
24
+ ggml_half d[N]; // deltas for N qK_0 blocks
25
+ int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
26
+ };
27
+
28
+ // control size
29
+ static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
30
+ static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
31
+ static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
32
+ static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
33
+
34
+ using block_q4_0x4 = block<4, 4>;
35
+ using block_q4_0x8 = block<4, 8>;
36
+ using block_q8_0x4 = block<8, 4>;
37
+ using block_q8_0x8 = block<8, 8>;
38
+
39
+ struct block_q4_Kx8 {
40
+ ggml_half d[8]; // super-block scale for quantized scales
41
+ ggml_half dmin[8]; // super-block scale for quantized mins
42
+ uint8_t scales[96]; // scales and mins, quantized with 6 bits
43
+ uint8_t qs[1024]; // 4-bit quants
44
+ };
45
+
46
+ static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
47
+
48
+ struct block_q8_Kx4 {
49
+ float d[4]; // delta
50
+ int8_t qs[QK_K * 4]; // quants
51
+ int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
52
+ };
53
+
54
+ static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
55
+
56
+ struct block_iq4_nlx4 {
57
+ ggml_half d[4]; // deltas for 4 iq4_nl blocks
58
+ uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
59
+ };
60
+
61
+ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
62
+
63
+ #if defined(__cplusplus)
64
+ extern "C" {
65
+ #endif
66
+
67
+ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
68
+ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
69
+ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
70
+ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
71
+ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
72
+ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
73
+ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
74
+ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
75
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
76
+ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
77
+ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
78
+ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
79
+ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
80
+
81
+ // Generic implementations (`_generic` counterparts of the functions above)
82
+ void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
83
+ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
84
+ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
85
+ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
86
+ void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
87
+ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
88
+ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
89
+ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
90
+ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
91
+ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
92
+ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
93
+ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
94
+ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
95
+
96
+ #if defined(__cplusplus)
97
+ } // extern "C"
98
+ #endif
@@ -944,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
944
944
  for (int i = 0; i < offset; ++i) { \
945
945
  x[i] = vec_add(x[i], x[offset + i]); \
946
946
  } \
947
- res = vec_extract(x[0], 0) + \
948
- vec_extract(x[0], 1) + \
949
- vec_extract(x[0], 2) + \
950
- vec_extract(x[0], 3); \
947
+ float32x4_t tmp = x[0] + vec_reve(x[0]); \
948
+ res = tmp[0] + tmp[1]; \
951
949
  }
952
950
 
953
951
  #define GGML_F32_VEC GGML_F32x4
@@ -1,4 +1,4 @@
1
- #include "ggml-cpu-traits.h"
1
+ #include "traits.h"
2
2
 
3
3
  #include "ggml-backend-impl.h"
4
4
  #include "ggml-backend.h"
@@ -207,9 +207,9 @@ typedef float2 dfloat2;
207
207
  #define FP16_MMA_AVAILABLE
208
208
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
209
209
 
210
- #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
210
+ #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
211
211
  #define FP16_MMA_AVAILABLE
212
- #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
212
+ #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
213
213
 
214
214
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
215
215
  #define NEW_MMA_AVAILABLE
@@ -262,11 +262,11 @@ static bool cp_async_available(const int cc) {
262
262
  }
263
263
 
264
264
  static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
265
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
266
- return __AMDGCN_WAVEFRONT_SIZE;
265
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
266
+ return 64;
267
267
  #else
268
268
  return 32;
269
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
269
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
270
270
  }
271
271
 
272
272
  [[noreturn]]
@@ -466,9 +466,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
466
466
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
467
467
  }
468
468
 
469
- // TODO: move to ggml-common.h
470
- static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
471
-
472
469
  typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
473
470
 
474
471
  static __device__ __forceinline__ float get_alibi_slope(
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
652
652
  float KQ_max_scale[cols_per_thread];
653
653
  #pragma unroll
654
654
  for (int col = 0; col < cols_per_thread; ++col) {
655
- KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
655
+ const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
656
+ KQ_max_scale[col] = expf(KQ_max_diff);
656
657
  KQ_max[col] = KQ_max_new[col];
657
658
 
659
+ *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
660
+
658
661
  // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
659
662
  KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
660
663
  }
@@ -615,9 +615,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
615
615
  ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
616
616
 
617
617
  ggml_cuda_set_device(ctx->device);
618
- CUDA_CHECK(cudaDeviceSynchronize());
619
- CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
620
- CUDA_CHECK(cudaDeviceSynchronize());
618
+ CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
619
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
621
620
  }
622
621
 
623
622
  static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
@@ -1144,7 +1143,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
1144
1143
  static cudaError_t ggml_cuda_cpy_tensor_2d(
1145
1144
  void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
1146
1145
 
1147
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
1148
1146
  const char * src_ptr = (const char *) src->data;
1149
1147
  char * dst_ptr = (char *) dst;
1150
1148
 
@@ -1427,8 +1425,6 @@ static void ggml_cuda_op_mul_mat(
1427
1425
  const int64_t nb2 = dst->nb[2];
1428
1426
  const int64_t nb3 = dst->nb[3];
1429
1427
 
1430
- GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
1431
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
1432
1428
  ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
1433
1429
  ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
1434
1430
 
@@ -1750,7 +1746,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
1750
1746
  GGML_ASSERT(!ggml_is_transposed(src0));
1751
1747
  GGML_ASSERT(!ggml_is_transposed(src1));
1752
1748
 
1753
- GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
1749
+ GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
1754
1750
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
1755
1751
 
1756
1752
  // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
@@ -2668,7 +2664,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2668
2664
  ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
2669
2665
  }
2670
2666
  }
2671
- #endif
2667
+ #else
2668
+ GGML_UNUSED(integrated);
2669
+ #endif // NDEBUG
2672
2670
 
2673
2671
  bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
2674
2672
  if (!ok) {
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
10
10
  float * __restrict__ dst, const int64_t L) {
11
11
  GGML_UNUSED(src1_nb0);
12
12
  GGML_UNUSED(src2_nb0);
13
+
14
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
13
15
  const int bidx = blockIdx.x; // split along B
14
16
  const int bidy = blockIdx.y; // split along D
15
17
  const int tid = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
44
46
  if (N == 16) {
45
47
  #pragma unroll
46
48
  for (size_t i = 0; i < splitD / 4; i += 2) {
47
- float value = A_block[(wid * warpSize + i) * stride_A + wtid];
49
+ float value = A_block[(wid * warp_size + i) * stride_A + wtid];
48
50
  // todo: bank conflict
49
51
  // I am always confused with how to use the swizzling method to solve
50
52
  // bank conflicts. Hoping somebody can tell me.
51
- smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
53
+ smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
52
54
  }
53
55
  #pragma unroll
54
56
  for (size_t i = 0; i < splitD / 4; i += 2) {
55
- float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
56
- smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
57
+ float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
58
+ smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
57
59
  }
58
60
  }
59
61
 
@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
113
113
  add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
114
114
  endif()
115
115
 
116
+ if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
117
+ add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
118
+ endif()
119
+
116
120
  if (NOT GGML_CUDA_FA)
117
121
  add_compile_definitions(GGML_CUDA_NO_FA)
118
122
  endif()
@@ -44,21 +44,22 @@ if (GGML_METAL_EMBED_LIBRARY)
44
44
  set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
45
45
 
46
46
  add_custom_command(
47
- OUTPUT ${METALLIB_EMBED_ASM}
47
+ OUTPUT "${METALLIB_EMBED_ASM}"
48
48
  COMMAND echo "Embedding Metal library"
49
- COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP}
50
- COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED}
51
- COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM}
52
- COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM}
53
- COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM}
54
- COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
55
- COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM}
56
- COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM}
49
+ COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
50
+ COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
51
+ COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
52
+ COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
53
+ COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
54
+ COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
55
+ COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
56
+ COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
57
57
  DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
58
58
  COMMENT "Generate assembly for embedded Metal library"
59
+ VERBATIM
59
60
  )
60
61
 
61
- target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM})
62
+ target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
62
63
  else()
63
64
  if (GGML_METAL_SHADER_DEBUG)
64
65
  # custom command to do the following:
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
498
498
  GGML_METAL_KERNEL_TYPE_COS,
499
499
  GGML_METAL_KERNEL_TYPE_NEG,
500
500
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
501
+ GGML_METAL_KERNEL_TYPE_MEAN,
501
502
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
502
503
  GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
503
504
  GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1454
1455
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1455
1456
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1456
1457
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1458
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1457
1459
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
1458
1460
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
1459
1461
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1653
1655
  case GGML_OP_LOG:
1654
1656
  return false; // TODO: implement
1655
1657
  case GGML_OP_SUM_ROWS:
1658
+ case GGML_OP_MEAN:
1656
1659
  case GGML_OP_SOFT_MAX:
1657
1660
  case GGML_OP_GROUP_NORM:
1658
1661
  return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
2400
2403
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2401
2404
  } break;
2402
2405
  case GGML_OP_SUM_ROWS:
2406
+ case GGML_OP_MEAN:
2403
2407
  {
2404
2408
  GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
2405
2409
 
2406
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
2410
+ id<MTLComputePipelineState> pipeline = nil;
2411
+
2412
+ switch (dst->op) {
2413
+ case GGML_OP_SUM_ROWS:
2414
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
2415
+ break;
2416
+ case GGML_OP_MEAN:
2417
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
2418
+ break;
2419
+ default:
2420
+ GGML_ABORT("fatal error");
2421
+ }
2422
+
2423
+ int nth = 32; // SIMD width
2424
+
2425
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2426
+ nth *= 2;
2427
+ }
2407
2428
 
2429
+ nth = MIN(nth, ne00);
2408
2430
 
2409
2431
  ggml_metal_kargs_sum_rows args = {
2410
2432
  /*.ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
2434
2456
  };
2435
2457
 
2436
2458
  [encoder setComputePipelineState:pipeline];
2437
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2438
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2439
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
2459
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2460
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2461
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2462
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2440
2463
 
2441
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2464
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2442
2465
  } break;
2443
2466
  case GGML_OP_SOFT_MAX:
2444
2467
  {
@@ -4766,6 +4789,8 @@ static bool ggml_metal_encode_node(
4766
4789
  GGML_ASSERT(nqptg % 8 == 0);
4767
4790
  GGML_ASSERT(ncpsg % 32 == 0);
4768
4791
 
4792
+ const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
4793
+
4769
4794
  // 2*(2*ncpsg + nqptg)*(nsg)
4770
4795
  // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
4771
4796
  //
@@ -4773,7 +4798,7 @@ static bool ggml_metal_encode_node(
4773
4798
  // the shared memory needed for the simdgroups to load the KV cache
4774
4799
  // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
4775
4800
  //
4776
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
4801
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
4777
4802
 
4778
4803
  int64_t nsgmax = 2;
4779
4804
 
@@ -4810,9 +4835,9 @@ static bool ggml_metal_encode_node(
4810
4835
  // and store the soft_max values and the mask
4811
4836
  //
4812
4837
  // ne00*(nsg)
4813
- // each simdgroup has a full f16 head vector in shared mem to accumulate results
4838
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
4814
4839
  //
4815
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
4840
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
4816
4841
 
4817
4842
  int64_t nsgmax = 2;
4818
4843
  while (true) {