cui-llama.rn 1.6.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. package/README.md +35 -7
  2. package/android/src/main/CMakeLists.txt +16 -11
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +4 -1
  4. package/android/src/main/jni.cpp +20 -4
  5. package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
  9. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
  13. package/cpp/LICENSE +21 -0
  14. package/cpp/chat.cpp +1 -1
  15. package/cpp/common.cpp +17 -2
  16. package/cpp/common.h +7 -3
  17. package/cpp/ggml-alloc.c +4 -1
  18. package/cpp/ggml-cpp.h +1 -1
  19. package/cpp/ggml-cpu/amx/amx.cpp +221 -0
  20. package/cpp/ggml-cpu/amx/amx.h +8 -0
  21. package/cpp/ggml-cpu/amx/common.h +91 -0
  22. package/cpp/ggml-cpu/amx/mmq.cpp +2511 -0
  23. package/cpp/ggml-cpu/amx/mmq.h +10 -0
  24. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/binary-ops.h +1 -1
  25. package/cpp/ggml-cpu/common.h +72 -0
  26. package/cpp/{ggml-cpu-aarch64.cpp → ggml-cpu/ggml-cpu-aarch64.cpp} +809 -101
  27. package/cpp/{ggml-cpu.c → ggml-cpu/ggml-cpu.c} +109 -42
  28. package/cpp/{ggml-cpu.cpp → ggml-cpu/ggml-cpu.cpp} +3 -0
  29. package/cpp/{ops.cpp → ggml-cpu/ops.cpp} +246 -160
  30. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/ops.h +2 -20
  31. package/cpp/{sgemm.cpp → ggml-cpu/sgemm.cpp} +501 -0
  32. package/{ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers → cpp/ggml-cpu}/simd-mappings.h +7 -3
  33. package/{ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers → cpp/ggml-cpu}/unary-ops.h +1 -1
  34. package/cpp/ggml-cpu.h +5 -0
  35. package/cpp/ggml-impl.h +16 -9
  36. package/cpp/ggml-llama-sim.metallib +0 -0
  37. package/cpp/ggml-llama.metallib +0 -0
  38. package/cpp/ggml-metal.m +492 -47
  39. package/cpp/ggml.c +134 -244
  40. package/cpp/ggml.h +61 -94
  41. package/cpp/json-schema-to-grammar.cpp +3 -0
  42. package/cpp/llama-arch.cpp +46 -17
  43. package/cpp/llama-arch.h +9 -0
  44. package/cpp/llama-batch.cpp +5 -1
  45. package/cpp/llama-batch.h +2 -1
  46. package/cpp/llama-chat.cpp +31 -10
  47. package/cpp/llama-chat.h +3 -2
  48. package/cpp/llama-context.cpp +104 -489
  49. package/cpp/llama-context.h +14 -30
  50. package/cpp/llama-graph.cpp +69 -62
  51. package/cpp/llama-graph.h +21 -18
  52. package/cpp/llama-hparams.h +5 -0
  53. package/cpp/llama-kv-cache.cpp +1497 -391
  54. package/cpp/llama-kv-cache.h +272 -80
  55. package/cpp/llama-memory.h +11 -1
  56. package/cpp/llama-model.cpp +502 -176
  57. package/cpp/llama-model.h +13 -3
  58. package/cpp/llama-sampling.cpp +2 -1
  59. package/cpp/llama-vocab.cpp +8 -1
  60. package/cpp/llama.h +14 -11
  61. package/cpp/rn-llama.cpp +20 -172
  62. package/cpp/rn-llama.h +1 -5
  63. package/ios/CMakeLists.txt +13 -10
  64. package/ios/RNLlama.h +6 -0
  65. package/ios/RNLlama.mm +5 -0
  66. package/ios/RNLlamaContext.mm +26 -28
  67. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +7 -3
  68. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  69. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  70. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  71. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +61 -94
  72. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  73. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  74. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  75. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  76. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  77. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  78. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  79. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  80. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  81. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +14 -11
  82. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  83. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  84. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
  85. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  86. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  87. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  88. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  89. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  90. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  91. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  92. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  93. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  94. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  95. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  96. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  97. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  98. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  99. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  100. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  101. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  102. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  103. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +7 -3
  104. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpp.h +1 -1
  105. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +5 -0
  106. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +16 -9
  107. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +61 -94
  108. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +9 -0
  109. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +2 -1
  110. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +3 -2
  111. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +14 -30
  112. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +21 -18
  113. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +5 -0
  114. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  115. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +11 -1
  116. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +13 -3
  117. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +14 -11
  118. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +1 -5
  119. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
  120. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
  121. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +7 -3
  122. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpp.h +1 -1
  123. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +5 -0
  124. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +16 -9
  125. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +61 -94
  126. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +9 -0
  127. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +2 -1
  128. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +3 -2
  129. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +14 -30
  130. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +21 -18
  131. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +5 -0
  132. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +272 -80
  133. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +11 -1
  134. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +13 -3
  135. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +14 -11
  136. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +1 -5
  137. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
  138. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
  139. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  140. package/lib/module/NativeRNLlama.js.map +1 -1
  141. package/lib/typescript/NativeRNLlama.d.ts +4 -0
  142. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  143. package/package.json +1 -1
  144. package/src/NativeRNLlama.ts +5 -0
  145. package/cpp/binary-ops.h +0 -16
  146. package/cpp/ops.h +0 -128
  147. package/cpp/simd-mappings.h +0 -888
  148. package/cpp/unary-ops.h +0 -28
  149. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  150. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  151. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  152. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  153. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  154. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ops.h +0 -128
  155. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  156. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  157. package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/vec.h +0 -802
  158. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  159. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  160. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  161. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  162. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  163. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  164. package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  165. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/binary-ops.h +0 -16
  166. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  167. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  168. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  169. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  170. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ops.h +0 -128
  171. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/sgemm.h +0 -14
  172. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/simd-mappings.h +0 -888
  173. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/unary-ops.h +0 -28
  174. package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/vec.h +0 -802
  175. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/binary-ops.h +0 -16
  176. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-aarch64.h +0 -8
  177. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-impl.h +0 -512
  178. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-quants.h +0 -63
  179. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu-traits.h +0 -38
  180. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ops.h +0 -128
  181. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/sgemm.h +0 -14
  182. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/simd-mappings.h +0 -888
  183. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/unary-ops.h +0 -28
  184. package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/vec.h +0 -802
  185. /package/cpp/{binary-ops.cpp → ggml-cpu/binary-ops.cpp} +0 -0
  186. /package/cpp/{ggml-cpu-aarch64.h → ggml-cpu/ggml-cpu-aarch64.h} +0 -0
  187. /package/cpp/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -0
  188. /package/cpp/{ggml-cpu-quants.c → ggml-cpu/ggml-cpu-quants.c} +0 -0
  189. /package/cpp/{ggml-cpu-quants.h → ggml-cpu/ggml-cpu-quants.h} +0 -0
  190. /package/cpp/{ggml-cpu-traits.cpp → ggml-cpu/ggml-cpu-traits.cpp} +0 -0
  191. /package/cpp/{ggml-cpu-traits.h → ggml-cpu/ggml-cpu-traits.h} +0 -0
  192. /package/cpp/{sgemm.h → ggml-cpu/sgemm.h} +0 -0
  193. /package/cpp/{unary-ops.cpp → ggml-cpu/unary-ops.cpp} +0 -0
  194. /package/cpp/{vec.cpp → ggml-cpu/vec.cpp} +0 -0
  195. /package/cpp/{vec.h → ggml-cpu/vec.h} +0 -0
package/cpp/ggml-metal.m CHANGED
@@ -44,8 +44,8 @@ static struct lm_ggml_backend_device g_lm_ggml_backend_metal_device;
  // note: assumes single GPU device - the default one
  // TODO: support multiple GPU devices
  static struct lm_ggml_backend_metal_device_context {
- id<MTLDevice> mtl_device;
- int mtl_device_ref_count;
+ id<MTLDevice> mtl_device;
+ int mtl_device_ref_count;
  id<MTLLibrary> mtl_library;

  bool has_simdgroup_reduction;
@@ -354,6 +354,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -362,6 +363,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -370,6 +372,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -378,6 +381,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -386,6 +390,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -394,6 +399,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -402,6 +408,14 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
@@ -430,6 +444,13 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
  LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
+ LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
  LM_GGML_METAL_KERNEL_TYPE_SET_I32,
  LM_GGML_METAL_KERNEL_TYPE_SET_F32,
  LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -460,6 +481,7 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_SQRT,
  LM_GGML_METAL_KERNEL_TYPE_SIN,
  LM_GGML_METAL_KERNEL_TYPE_COS,
+ LM_GGML_METAL_KERNEL_TYPE_NEG,
  LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
  LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
  LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -468,7 +490,259 @@ enum lm_ggml_metal_kernel_type {
  LM_GGML_METAL_KERNEL_TYPE_COUNT
  };

+ //
+ // lm_ggml_metal_heap
+ //
+
+ struct lm_ggml_metal_heap {
+ // number of times the heap was unused
+ int n_unused;
+
+ // total number of buffer allocations in this heap across all computes
+ int64_t n_alloc;
+
+ // current offset in the heap - we reset this after each node in order to reuse the memory
+ size_t offs;
+
+ // the currently allocated MTLBuffer objects in this heap
+ id<MTLHeap> obj;
+
+ NSMutableArray * bufs;
+ };
+
+ static struct lm_ggml_metal_heap * lm_ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
+ struct lm_ggml_metal_heap * heap = calloc(1, sizeof(struct lm_ggml_metal_heap));
+
+ MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
+ desc.storageMode = MTLStorageModePrivate;
+ desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+ desc.type = MTLHeapTypePlacement;
+ desc.size = size;
+
+ heap->n_unused = 0;
+ heap->n_alloc = 0;
+
+ heap->obj = [device newHeapWithDescriptor:desc];
+ if (!heap->obj) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
+
+ free(heap);
+
+ return false;
+ }
+
+ [desc release];
+
+ heap->bufs = [[NSMutableArray alloc] init];
+
+ return heap;
+ }
+
+ static void lm_ggml_metal_heap_reset(struct lm_ggml_metal_heap * heap) {
+ heap->offs = 0;
+
+ // count how many graph computes the heap ended up being unused
+ if ([heap->bufs count] > 0) {
+ heap->n_unused = 0;
+ } else {
+ heap->n_unused++;
+ }
+
+ for (id<MTLBuffer> buf in heap->bufs) {
+ [buf release];
+ }
+ [heap->bufs removeAllObjects];
+
+ // tell the OS that it can reuse this memory if needed
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+ [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
+ }
+
+ static void lm_ggml_metal_heap_free(struct lm_ggml_metal_heap * heap) {
+ if (heap == nil) {
+ return;
+ }
+
+ lm_ggml_metal_heap_reset(heap);
+
+ [heap->obj release];
+ [heap->bufs release];
+
+ free(heap);
+ }
+
+ @interface lm_ggml_metal_heap_ptr : NSObject
+
+ @property (nonatomic, assign) struct lm_ggml_metal_heap * data;
+
+ @end
+
+ @implementation lm_ggml_metal_heap_ptr
+ @end
+
+ //
+ // lm_ggml_metal_mem_pool
+ //
+
+ struct lm_ggml_metal_mem_pool {
+ id<MTLDevice> device;
+
+ int n_heaps; // total number of heaps ever created (including those that were removed)
+
+ NSMutableArray * heaps;
+ NSMutableArray * heaps_to_remove;
+ };
+
+ static struct lm_ggml_metal_mem_pool * lm_ggml_metal_mem_pool_init(void) {
+ struct lm_ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct lm_ggml_metal_mem_pool));
+
+ mem_pool->n_heaps = 0;
+
+ mem_pool->heaps = [[NSMutableArray alloc] init];
+ mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
+
+ return mem_pool;
+ }
+
+ static void lm_ggml_metal_mem_pool_free(struct lm_ggml_metal_mem_pool * mem_pool) {
+ LM_GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
+
+ size_t size_all = 0;
+ size_t size_cur = 0;
+
+ for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+ LM_GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
+ LM_GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
+ LM_GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
+ LM_GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
+ LM_GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
+
+ if ([ptr.data->bufs count] > 0) {
+ size_cur += [ptr.data->obj size];
+ }
+ size_all += [ptr.data->obj size];
+
+ lm_ggml_metal_heap_free(ptr.data);
+ [ptr release];
+ }
+ [mem_pool->heaps release];
+ [mem_pool->heaps_to_remove release];
+
+ if (size_all > 0) {
+ LM_GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
+ LM_GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
+ }
+
+ free(mem_pool);
+ }
+
+ static void lm_ggml_metal_mem_pool_reset(struct lm_ggml_metal_mem_pool * mem_pool) {
+ for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
+ lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
+
+ struct lm_ggml_metal_heap * heap = ptr.data;
+ lm_ggml_metal_heap_reset(heap);
+
+ // if the heap hasn't been used for a while, remove it
+ if (heap->n_unused >= 128) {
+ [mem_pool->heaps_to_remove addObject:@(i)];
+ }
+ }
+
+ if (mem_pool->heaps_to_remove.count > 0) {
+ for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+ NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
+ lm_ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
+
+ struct lm_ggml_metal_heap * heap = ptr.data;
+ lm_ggml_metal_heap_free(heap);
+
+ [mem_pool->heaps removeObjectAtIndex:index];
+ [ptr release];
+ }
+
+ [mem_pool->heaps_to_remove removeAllObjects];
+ }
+ }
+
+ static void lm_ggml_metal_mem_pool_clear(struct lm_ggml_metal_mem_pool * mem_pool) {
+ for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+ ptr.data->offs = 0;
+ }
+ }
+
+ static id<MTLBuffer> lm_ggml_metal_mem_pool_alloc(struct lm_ggml_metal_mem_pool * mem_pool, size_t size) {
+ const size_t alignment = 32;
+
+ const size_t size_aligned = LM_GGML_PAD(size, alignment);
+
+ // try one of the existing heaps
+ for (lm_ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+ struct lm_ggml_metal_heap * heap = ptr.data;
+ if (heap->offs + size_aligned <= [heap->obj size]) {
+ // if this is the first buffer in the heap for the current command buffer, tell the OS that
+ // it cannot free the memory used by the heap
+ // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+ if ([heap->bufs count] == 0) {
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+ }
+
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+ if (buf == nil) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+ return nil;
+ }
+
+ heap->n_alloc++;
+ heap->offs += size_aligned;
+
+ [heap->bufs addObject:buf];
+
+ return buf;
+ }
+ }
+
+ // create a new heap that can fit this buffer
+ lm_ggml_metal_heap_ptr * heap_ptr = [lm_ggml_metal_heap_ptr new];
+
+ struct lm_ggml_metal_heap * heap = lm_ggml_metal_heap_init(mem_pool->device, size_aligned);
+ if (heap == NULL) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
+ return NULL;
+ }
+
+ //LM_GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
+
+ heap_ptr.data = heap;
+ lm_ggml_metal_heap_reset(heap);
+
+ [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+ id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+ if (buf == nil) {
+ LM_GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+ return NULL;
+ }
+
+ heap->n_alloc++;
+ heap->offs += size_aligned;
+
+ [heap->bufs addObject:buf];
+
+ [mem_pool->heaps addObject:heap_ptr];
+ mem_pool->n_heaps++;
+
+ return buf;
+ }
+
+ struct lm_ggml_metal_command_buffer {
+ id<MTLCommandBuffer> obj;
+
+ // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
+ struct lm_ggml_metal_mem_pool * mem_pool;
+ };
+
  struct lm_ggml_backend_metal_context {
+ id<MTLDevice> device;
  id<MTLCommandQueue> queue;

  dispatch_queue_t d_queue;
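Note (illustrative, not part of the package source): the heap and mem-pool helpers added above give each Metal command buffer a small scratch allocator for temporary private MTLBuffers, carved out of reusable MTLHeap objects instead of fresh device allocations. A minimal sketch of how the added functions compose, assuming a valid id<MTLDevice> named device and a scratch size nbytes; the exact points where clear/reset are invoked are not all visible in this excerpt and are an assumption consistent with the helpers' semantics:

    // sketch only - intended lifecycle of the helpers introduced in this diff
    struct lm_ggml_metal_mem_pool * pool = lm_ggml_metal_mem_pool_init();
    pool->device = device;                  // done per command buffer in lm_ggml_metal_init (see below)

    // while encoding a node: grab a private scratch buffer from the pool
    id<MTLBuffer> tmp = lm_ggml_metal_mem_pool_alloc(pool, nbytes);
    if (tmp == nil) {
        // allocation failed - encoding for this node is aborted (see the soft_max test path later in this diff)
    }

    lm_ggml_metal_mem_pool_clear(pool);     // per node: rewind heap offsets so the memory is reused
    lm_ggml_metal_mem_pool_reset(pool);     // after a compute: release buffers, mark idle heaps purgeable
    lm_ggml_metal_mem_pool_free(pool);      // on backend teardown (see lm_ggml_metal_free below)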
@@ -493,7 +767,7 @@ struct lm_ggml_backend_metal_context {
  void (^encode_async)(size_t ith);

  // n_cb command buffers + 1 used by the main thread
- id<MTLCommandBuffer> command_buffers[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+ struct lm_ggml_metal_command_buffer cmd_bufs[LM_GGML_METAL_MAX_COMMAND_BUFFERS + 1];

  // abort lm_ggml_metal_graph_compute if callback returns true
  lm_ggml_abort_callback abort_callback;
@@ -687,9 +961,11 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  struct lm_ggml_backend_metal_device_context * ctx_dev = dev->context;

  id<MTLDevice> device = lm_ggml_backend_metal_device_acq(ctx_dev);
+
  LM_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);

- ctx->queue = [device newCommandQueue];
+ ctx->device = device;
+ ctx->queue = [device newCommandQueue];
  if (ctx->queue == nil) {
  LM_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
  return NULL;
@@ -750,7 +1026,10 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  ctx->gf = nil;
  ctx->encode_async = nil;
  for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
- ctx->command_buffers[i] = nil;
+ ctx->cmd_bufs[i].obj = nil;
+
+ ctx->cmd_bufs[i].mem_pool = lm_ggml_metal_mem_pool_init();
+ ctx->cmd_bufs[i].mem_pool->device = device;
  }

  #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1015,6 +1294,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
@@ -1023,6 +1303,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
@@ -1031,6 +1312,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
@@ -1039,6 +1321,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
@@ -1047,6 +1330,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
@@ -1055,6 +1339,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
@@ -1063,6 +1348,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
@@ -1091,6 +1384,13 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -1121,6 +1421,7 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
+ LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
  LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
@@ -1141,6 +1442,12 @@ static void lm_ggml_metal_free(struct lm_ggml_backend_metal_context * ctx) {

  [ctx->queue release];

+ for (int i = 0; i < LM_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+ // ctx->cmd_bufs[i].obj is auto released
+
+ lm_ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+ }
+
  dispatch_release(ctx->d_queue);

  free(ctx);
@@ -1282,6 +1589,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  case LM_GGML_UNARY_OP_GELU_QUICK:
  case LM_GGML_UNARY_OP_SILU:
  case LM_GGML_UNARY_OP_ELU:
+ case LM_GGML_UNARY_OP_NEG:
  return lm_ggml_is_contiguous(op->src[0]) && op->src[0]->type == LM_GGML_TYPE_F32;
  default:
  return false;
@@ -1338,8 +1646,9 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  return op->src[0]->type == LM_GGML_TYPE_F16;
  case LM_GGML_OP_POOL_1D:
  return false;
- case LM_GGML_OP_POOL_2D:
  case LM_GGML_OP_UPSCALE:
+ return op->src[0]->type == LM_GGML_TYPE_F32 && op->op_params[0] == LM_GGML_SCALE_MODE_NEAREST;
+ case LM_GGML_OP_POOL_2D:
  case LM_GGML_OP_PAD:
  case LM_GGML_OP_PAD_REFLECT_1D:
  case LM_GGML_OP_TIMESTEP_EMBEDDING:
@@ -1354,6 +1663,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  // TODO: not sure if it is worth adding kernels for this size
  return false;
  }
+ if (op->src[0]->ne[0] == 576) {
+ // DeepSeek sizes
+ // TODO: disabled for now, until optmized
+ return false;
+ }
  if (op->src[1]->type != op->src[2]->type) {
  return false;
  }
@@ -1439,10 +1753,11 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
  }
  }

- static void lm_ggml_metal_encode_node(
+ static bool lm_ggml_metal_encode_node(
  lm_ggml_backend_t backend,
  int idx,
- id<MTLComputeCommandEncoder> encoder) {
+ id<MTLComputeCommandEncoder> encoder,
+ struct lm_ggml_metal_mem_pool * mem_pool) {
  struct lm_ggml_backend_metal_context * ctx = backend->context;
  struct lm_ggml_backend_metal_device_context * ctx_dev = backend->device->context;

@@ -1458,7 +1773,7 @@ static void lm_ggml_metal_encode_node(
  struct lm_ggml_tensor * dst = node;

  if (lm_ggml_is_empty(dst)) {
- return;
+ return true;
  }

  switch (dst->op) {
@@ -1469,7 +1784,7 @@ static void lm_ggml_metal_encode_node(
  case LM_GGML_OP_PERMUTE:
  {
  // noop -> next node
- } return;
+ } return true;
  default:
  {
  } break;
@@ -1480,6 +1795,8 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ABORT("unsupported op");
  }

+ lm_ggml_metal_mem_pool_clear(mem_pool);
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1966,6 +2283,18 @@ static void lm_ggml_metal_encode_node(

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case LM_GGML_UNARY_OP_NEG:
+ {
+ id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = lm_ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  default:
  {
  LM_GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
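The NEG case above only wires up the host-side dispatch; the Metal shader itself is compiled into the bundled metallib and is not shown in this diff. Per the supports_op change earlier in the file, the op is limited to contiguous F32 tensors, and it is plain elementwise negation. A reference loop in C, for illustration only (src0, dst, and n mirror the encoder code; this is not a function in the package):

    // reference semantics of LM_GGML_UNARY_OP_NEG on a contiguous F32 tensor
    // n = lm_ggml_nelements(dst)
    for (int64_t i = 0; i < n; ++i) {
        dst[i] = -src0[i];
    }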
@@ -2114,26 +2443,76 @@ static void lm_ggml_metal_encode_node(
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

- lm_ggml_metal_kargs_soft_max args = {
+ // use this branch to test the lm_ggml_metal_mem_pool functionality
+ #if 0
+ // cpy to tmp buffer in MTLHeap
+
+ id<MTLBuffer> h_src0 = h_src0 = lm_ggml_metal_mem_pool_alloc(mem_pool, lm_ggml_nbytes(src0));
+ if (!h_src0) {
+ LM_GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, lm_ggml_nbytes(src0));
+ return false;
+ }
+
+ offs_src0 = 0;
+
+ lm_ggml_metal_kargs_cpy args_cpy = {
  /*.ne00 =*/ ne00,
  /*.ne01 =*/ ne01,
  /*.ne02 =*/ ne02,
- /*.scale =*/ scale,
- /*.max_bias =*/ max_bias,
- /*.m0 =*/ m0,
- /*.m1 =*/ m1,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne0 =*/ ne00,
+ /*.ne1 =*/ ne01,
+ /*.ne2 =*/ ne02,
+ /*.ne3 =*/ ne03,
+ /*.nb0 =*/ nb00,
+ /*.nb1 =*/ nb01,
+ /*.nb2 =*/ nb02,
+ /*.nb3 =*/ nb03,
+ };
+
+ if (src0->type == LM_GGML_TYPE_F16) {
+ [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+ } else {
+ [encoder setComputePipelineState:ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+ }
+ [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:h_src0 offset:0 atIndex:2];
+
+ LM_GGML_ASSERT(ne00 % lm_ggml_blck_size(src0->type) == 0);
+ int nth_cpy = MIN(1024, ne00 / lm_ggml_blck_size(src0->type));
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+ #else
+ id<MTLBuffer> h_src0 = id_src0;
+ #endif
+ // softmax
+
+ lm_ggml_metal_kargs_soft_max args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
  /*.n_head_log2 =*/ n_head_log2,
  };

  [encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
  if (id_src1) {
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
  }
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];

  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

@@ -3846,12 +4225,14 @@ static void lm_ggml_metal_encode_node(
  // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
  // for now avoiding mainly to keep the number of templates/kernels a bit lower
  // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
- if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) {
+ if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
  switch (src1->type) {
  case LM_GGML_TYPE_F16:
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
@@ -3874,6 +4255,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3896,6 +4279,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3918,6 +4303,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3940,6 +4327,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -3962,6 +4351,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -3984,6 +4375,8 @@ static void lm_ggml_metal_encode_node(
  {
  if (ne00 == 192 && ne20 == 128) {
  pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+ } else if (ne00 == 576 && ne20 == 512) {
+ pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
  } else {
  switch (ne00) {
  case 64: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
@@ -4013,6 +4406,24 @@ static void lm_ggml_metal_encode_node(
4013
4406
  use_vec_kernel = true;
4014
4407
 
4015
4408
  switch (ne00) {
4409
+ case 96:
4410
+ {
4411
+ switch (src1->type) {
4412
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
4413
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
4414
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
4415
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
4416
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
4417
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
4418
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
4419
+ default:
4420
+ {
4421
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
4422
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
4423
+ LM_GGML_ABORT("add template specialization for this type");
4424
+ }
4425
+ }
4426
+ } break;
4016
4427
  case 128:
4017
4428
  {
4018
4429
  switch (src1->type) {
@@ -4085,12 +4496,36 @@ static void lm_ggml_metal_encode_node(
  }
  }
  } break;
+ case 576:
+ {
+ if (ne20 == 512) {
+ switch (src1->type) {
+ case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+ case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+ default:
+ {
+ LM_GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+ LM_GGML_LOG_ERROR("add template specialization for this type\n");
+ LM_GGML_ABORT("add template specialization for this type");
+ }
+ }
+ } else {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
+ } break;
  default:
- {
- LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- LM_GGML_LOG_ERROR("add template specialization for this size\n");
- LM_GGML_ABORT("add template specialization for this size");
- }
+ {
+ LM_GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+ LM_GGML_LOG_ERROR("add template specialization for this size\n");
+ LM_GGML_ABORT("add template specialization for this size");
+ }
  }
  }
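Note (not part of the upstream diff): on the vector flash-attention path the head-size switch gains a case 96 and a case 576, the latter accepted only when the V head size ne20 is 512; each new case then dispatches on the KV-cache type (F16, BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0). A compact C sketch of that dispatch, with a hypothetical stand-in for the kernel lookup:

    #include <stdio.h>
    #include <stdlib.h>

    /* hypothetical stand-in: the real code picks a distinct pipeline enum per (type, HK, HV) */
    static int vec_kernel_for(int kv_type, int hk, int hv) {
        return kv_type * 1000000 + hk * 1000 + hv;
    }

    static int pick_vec_kernel(int ne00, int ne20, int kv_type) {
        switch (ne00) {
            case 96:  return vec_kernel_for(kv_type, 96, 96);            /* new in 1.6.1 */
            case 576:                                                    /* new in 1.6.1 */
                if (ne20 == 512) return vec_kernel_for(kv_type, 576, 512);
                abort();  /* mirrors LM_GGML_ABORT for an unsupported V head size */
            default:
                return vec_kernel_for(kv_type, ne00, ne20);              /* 128, 192, 256, ... */
        }
    }

    int main(void) {
        printf("%d\n", pick_vec_kernel(576, 512, 0));
        return 0;
    }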
 
@@ -4486,6 +4921,8 @@ static void lm_ggml_metal_encode_node(
  LM_GGML_ABORT("fatal error");
  }
  }
+
+ return true;
  }
 
  static enum lm_ggml_status lm_ggml_metal_graph_compute(
@@ -4539,25 +4976,25 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
  }
 
  // the main thread commits the first few commands immediately
- // command_buffer[n_cb]
+ // cmd_buf[n_cb]
  {
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- ctx->command_buffers[n_cb] = command_buffer;
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
- [command_buffer enqueue];
+ [cmd_buf enqueue];
  ctx->encode_async(n_cb);
  }
 
  // prepare the rest of the command buffers asynchronously
- // command_buffer[0.. n_cb)
+ // cmd_buf[0.. n_cb)
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
- ctx->command_buffers[cb_idx] = command_buffer;
+ id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+ ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
  // always enqueue the first two command buffers
  // enqueue all of the command buffers if we don't need to abort
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer enqueue];
+ [cmd_buf enqueue];
  }
  }
 
@@ -4566,14 +5003,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
  // wait for completion and check status of each command buffer
  // needed to detect if the device ran out-of-memory for example (#1881)
  {
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
- [command_buffer waitUntilCompleted];
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+ [cmd_buf waitUntilCompleted];
 
- MTLCommandBufferStatus status = [command_buffer status];
+ MTLCommandBufferStatus status = [cmd_buf status];
  if (status != MTLCommandBufferStatusCompleted) {
  LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
  if (status == MTLCommandBufferStatusError) {
- LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+ LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
  }
 
  return LM_GGML_STATUS_FAILED;
@@ -4581,20 +5018,20 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
  }
 
  for (int i = 0; i < n_cb; ++i) {
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
- [command_buffer waitUntilCompleted];
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+ [cmd_buf waitUntilCompleted];
 
- MTLCommandBufferStatus status = [command_buffer status];
+ MTLCommandBufferStatus status = [cmd_buf status];
  if (status != MTLCommandBufferStatusCompleted) {
  LM_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  if (status == MTLCommandBufferStatusError) {
- LM_GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+ LM_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
  }
 
  return LM_GGML_STATUS_FAILED;
  }
 
- id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
  if (!next_buffer) {
  continue;
  }
@@ -4977,8 +5414,9 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
 
  const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+ id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+ id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
  int node_start = 0;
  int node_end = n_nodes_0;
@@ -4990,22 +5428,29 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
 
  const bool should_capture = ctx->capture_next_compute;
 
+ struct lm_ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+ lm_ggml_metal_mem_pool_reset(mem_pool);
+
  for (int idx = node_start; idx < node_end; ++idx) {
  if (should_capture) {
  [encoder pushDebugGroup:[NSString stringWithCString:lm_ggml_op_desc(lm_ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
  }
 
- lm_ggml_metal_encode_node(backend, idx, encoder);
+ const bool res = lm_ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
  if (should_capture) {
  [encoder popDebugGroup];
  }
+
+ if (!res) {
+ break;
+ }
  }
 
  [encoder endEncoding];
 
  if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer commit];
+ [cmd_buf commit];
  }
  });
  }
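Note (not part of the upstream diff): the remaining hunks rename the per-context array of command buffers (command_buffers -> cmd_bufs), pair each Metal command buffer with a per-buffer memory pool that is reset before a graph is encoded, and change lm_ggml_metal_encode_node to return a bool so that encoding stops at the first node that fails to encode. A hedged Objective-C sketch of the wrapper record these hunks imply; the real definition (and its exact name) lives elsewhere in ggml-metal.m:

    #import <Metal/Metal.h>

    struct lm_ggml_metal_mem_pool;  // opaque here; provides per-command-buffer scratch allocations

    // hypothetical shape suggested by the ".obj" / ".mem_pool" accesses in the hunks above
    struct lm_ggml_metal_command_buffer {
        id<MTLCommandBuffer>            obj;       // the enqueued Metal command buffer
        struct lm_ggml_metal_mem_pool * mem_pool;  // reset via lm_ggml_metal_mem_pool_reset() before encoding
    };

The encode callback then resets the pool, encodes nodes until lm_ggml_metal_encode_node() returns false, ends the encoder, and commits the command buffer.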