@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (253)
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
48
48
  int mtl_device_ref_count;
49
49
  id<MTLLibrary> mtl_library;
50
50
 
51
+ NSLock * mtl_lock;
52
+
51
53
  bool has_simdgroup_reduction;
52
54
  bool has_simdgroup_mm;
53
55
  bool has_residency_sets;
54
56
  bool has_bfloat;
55
57
  bool use_bfloat;
56
58
 
59
+ size_t max_size;
60
+
57
61
  char name[128];
58
62
  } g_ggml_ctx_dev_main = {
59
63
  /*.mtl_device =*/ nil,
60
64
  /*.mtl_device_ref_count =*/ 0,
61
65
  /*.mtl_library =*/ nil,
66
+ /*.mtl_lock =*/ nil,
62
67
  /*.has_simdgroup_reduction =*/ false,
63
68
  /*.has_simdgroup_mm =*/ false,
64
69
  /*.has_residency_sets =*/ false,
65
70
  /*.has_bfloat =*/ false,
66
71
  /*.use_bfloat =*/ false,
72
+ /*.max_size =*/ 0,
67
73
  /*.name =*/ "",
68
74
  };
69
75
 
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
71
77
  static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
72
78
  assert(ctx != NULL);
73
79
 
80
+ if (ctx->mtl_lock == nil) {
81
+ ctx->mtl_lock = [[NSLock alloc] init];
82
+ }
83
+
74
84
  if (ctx->mtl_device == nil) {
75
85
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
76
86
  }
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
94
104
  ctx->use_bfloat = false;
95
105
  #endif
96
106
 
107
+ ctx->max_size = ctx->mtl_device.maxBufferLength;
108
+
97
109
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
98
110
  }
99
111
 
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
110
122
  ctx->mtl_device_ref_count--;
111
123
 
112
124
  if (ctx->mtl_device_ref_count == 0) {
125
+ if (ctx->mtl_lock) {
126
+ [ctx->mtl_lock release];
127
+ ctx->mtl_lock = nil;
128
+ }
129
+
113
130
  if (ctx->mtl_library) {
114
131
  [ctx->mtl_library release];
115
132
  ctx->mtl_library = nil;
@@ -185,6 +202,15 @@ enum ggml_metal_kernel_type {
185
202
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
186
203
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
187
204
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
188
214
  GGML_METAL_KERNEL_TYPE_RMS_NORM,
189
215
  GGML_METAL_KERNEL_TYPE_L2_NORM,
190
216
  GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -194,11 +220,14 @@ enum ggml_metal_kernel_type {
194
220
  GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
195
221
  GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
196
222
  GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
223
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
197
224
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
225
+ GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
198
226
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
199
227
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
200
228
  GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
201
229
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
230
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
202
231
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
203
232
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
204
233
  GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -498,6 +527,7 @@ enum ggml_metal_kernel_type {
498
527
  GGML_METAL_KERNEL_TYPE_COS,
499
528
  GGML_METAL_KERNEL_TYPE_NEG,
500
529
  GGML_METAL_KERNEL_TYPE_SUM_ROWS,
530
+ GGML_METAL_KERNEL_TYPE_MEAN,
501
531
  GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
502
532
  GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
503
533
  GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -976,7 +1006,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
976
1006
  struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
977
1007
  struct ggml_backend_metal_device_context * ctx_dev = dev->context;
978
1008
 
979
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
1009
+ id<MTLDevice> device = ctx_dev->mtl_device;
980
1010
 
981
1011
  GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
982
1012
 
@@ -990,9 +1020,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
990
1020
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
991
1021
 
992
1022
  // load library
993
- if (ctx_dev->mtl_library == nil) {
994
- ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1023
+ {
1024
+ [ctx_dev->mtl_lock lock];
1025
+
1026
+ if (ctx_dev->mtl_library == nil) {
1027
+ ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
1028
+ }
1029
+
1030
+ [ctx_dev->mtl_lock unlock];
995
1031
  }
1032
+
996
1033
  id<MTLLibrary> metal_library = ctx_dev->mtl_library;
997
1034
  if (metal_library == nil) {
998
1035
  GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -1141,6 +1178,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1141
1178
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
1142
1179
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
1143
1180
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1181
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1182
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1183
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1184
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
1185
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
1186
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
1187
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
1188
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1189
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1144
1190
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1145
1191
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1146
1192
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1150,11 +1196,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1150
1196
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1151
1197
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1152
1198
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
1199
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
1153
1200
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
1201
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
1154
1202
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
1155
1203
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
1156
1204
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
1157
1205
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
1206
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
1158
1207
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
1159
1208
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
1160
1209
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -1454,6 +1503,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
1454
1503
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
1455
1504
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1456
1505
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1506
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1457
1507
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
1458
1508
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
1459
1509
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1603,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1603
1653
  const bool use_bfloat = ctx_dev->use_bfloat;
1604
1654
 
1605
1655
  if (!use_bfloat) {
1656
+ if (op->type == GGML_TYPE_BF16) {
1657
+ return false;
1658
+ }
1659
+
1606
1660
  for (size_t i = 0, n = 3; i < n; ++i) {
1607
1661
  if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1608
1662
  return false;
@@ -1653,6 +1707,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1653
1707
  case GGML_OP_LOG:
1654
1708
  return false; // TODO: implement
1655
1709
  case GGML_OP_SUM_ROWS:
1710
+ case GGML_OP_MEAN:
1656
1711
  case GGML_OP_SOFT_MAX:
1657
1712
  case GGML_OP_GROUP_NORM:
1658
1713
  return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -1771,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1771
1826
  {
1772
1827
  return op->ne[3] == 1;
1773
1828
  }
1829
+ case GGML_OP_SET_ROWS:
1830
+ {
1831
+ if (op->src[0]->type != GGML_TYPE_F32) {
1832
+ return false;
1833
+ }
1834
+
1835
+ switch (op->type) {
1836
+ case GGML_TYPE_F32:
1837
+ case GGML_TYPE_F16:
1838
+ case GGML_TYPE_BF16:
1839
+ case GGML_TYPE_Q8_0:
1840
+ case GGML_TYPE_Q4_0:
1841
+ case GGML_TYPE_Q4_1:
1842
+ case GGML_TYPE_Q5_0:
1843
+ case GGML_TYPE_Q5_1:
1844
+ case GGML_TYPE_IQ4_NL:
1845
+ return true;
1846
+ default:
1847
+ return false;
1848
+ };
1849
+ }
1774
1850
  default:
1775
1851
  return false;
1776
1852
  }
@@ -2400,11 +2476,31 @@ static bool ggml_metal_encode_node(
2400
2476
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2401
2477
  } break;
2402
2478
  case GGML_OP_SUM_ROWS:
2479
+ case GGML_OP_MEAN:
2403
2480
  {
2404
2481
  GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
2405
2482
 
2406
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
2483
+ id<MTLComputePipelineState> pipeline = nil;
2484
+
2485
+ switch (dst->op) {
2486
+ case GGML_OP_SUM_ROWS:
2487
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
2488
+ break;
2489
+ case GGML_OP_MEAN:
2490
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
2491
+ break;
2492
+ default:
2493
+ GGML_ABORT("fatal error");
2494
+ }
2407
2495
 
2496
+ int nth = 32; // SIMD width
2497
+
2498
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2499
+ nth *= 2;
2500
+ }
2501
+
2502
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
2503
+ nth = MIN(nth, ne00);
2408
2504
 
2409
2505
  ggml_metal_kargs_sum_rows args = {
2410
2506
  /*.ne00 =*/ ne00,
@@ -2434,11 +2530,12 @@ static bool ggml_metal_encode_node(
2434
2530
  };
2435
2531
 
2436
2532
  [encoder setComputePipelineState:pipeline];
2437
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2438
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2439
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
2533
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
2534
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2535
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2536
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
2440
2537
 
2441
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2538
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2442
2539
  } break;
2443
2540
  case GGML_OP_SOFT_MAX:
2444
2541
  {
@@ -3063,14 +3160,23 @@ static bool ggml_metal_encode_node(
3063
3160
  nsg = 1;
3064
3161
  nr0 = 1;
3065
3162
  nr1 = 4;
3066
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3163
+ if (ne00 == 4) {
3164
+ nr0 = 32;
3165
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
3166
+ } else {
3167
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3168
+ }
3067
3169
  } break;
3068
3170
  case GGML_TYPE_F16:
3069
3171
  {
3070
3172
  nsg = 1;
3071
3173
  nr0 = 1;
3072
3174
  if (src1t == GGML_TYPE_F32) {
3073
- if (ne11 * ne12 < 4) {
3175
+ if (ne00 == 4) {
3176
+ nr0 = 32;
3177
+ nr1 = 4;
3178
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
3179
+ } else if (ne11 * ne12 < 4) {
3074
3180
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
3075
3181
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3076
3182
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3089,7 +3195,11 @@ static bool ggml_metal_encode_node(
3089
3195
  nsg = 1;
3090
3196
  nr0 = 1;
3091
3197
  if (src1t == GGML_TYPE_F32) {
3092
- if (ne11 * ne12 < 4) {
3198
+ if (ne00 == 4) {
3199
+ nr0 = 32;
3200
+ nr1 = 4;
3201
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
3202
+ } else if (ne11 * ne12 < 4) {
3093
3203
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
3094
3204
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3095
3205
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
@@ -3710,13 +3820,74 @@ static bool ggml_metal_encode_node(
3710
3820
  };
3711
3821
 
3712
3822
  [encoder setComputePipelineState:pipeline];
3713
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3714
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3715
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3716
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
3823
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3824
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3825
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3826
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3717
3827
 
3718
3828
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
3719
3829
  } break;
3830
+ case GGML_OP_SET_ROWS:
3831
+ {
3832
+ id<MTLComputePipelineState> pipeline = nil;
3833
+
3834
+ switch (dst->type) {
3835
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3836
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3837
+ case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3838
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
3839
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
3840
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
3841
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
3842
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
3843
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
3844
+ default: GGML_ABORT("not implemented");
3845
+ }
3846
+
3847
+ const int32_t nk0 = ne0/ggml_blck_size(dst->type);
3848
+
3849
+ int nth = 32; // SIMD width
3850
+
3851
+ while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3852
+ nth *= 2;
3853
+ }
3854
+
3855
+ int nrptg = 1;
3856
+ if (nth > nk0) {
3857
+ nrptg = (nth + nk0 - 1)/nk0;
3858
+ nth = nk0;
3859
+
3860
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3861
+ nrptg--;
3862
+ }
3863
+ }
3864
+
3865
+ nth = MIN(nth, nk0);
3866
+
3867
+ ggml_metal_kargs_set_rows args = {
3868
+ /*.nk0 =*/ nk0,
3869
+ /*.ne01 =*/ ne01,
3870
+ /*.nb01 =*/ nb01,
3871
+ /*.nb02 =*/ nb02,
3872
+ /*.nb03 =*/ nb03,
3873
+ /*.ne11 =*/ ne11,
3874
+ /*.ne12 =*/ ne12,
3875
+ /*.nb10 =*/ nb10,
3876
+ /*.nb11 =*/ nb11,
3877
+ /*.nb12 =*/ nb12,
3878
+ /*.nb1 =*/ nb1,
3879
+ /*.nb2 =*/ nb2,
3880
+ /*.nb3 =*/ nb3,
3881
+ };
3882
+
3883
+ [encoder setComputePipelineState:pipeline];
3884
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3885
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3886
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3887
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3888
+
3889
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
3890
+ } break;
3720
3891
  case GGML_OP_RMS_NORM:
3721
3892
  {
3722
3893
  GGML_ASSERT(ne00 % 4 == 0);
@@ -3733,6 +3904,7 @@ static bool ggml_metal_encode_node(
3733
3904
  nth *= 2;
3734
3905
  }
3735
3906
 
3907
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3736
3908
  nth = MIN(nth, ne00/4);
3737
3909
 
3738
3910
  ggml_metal_kargs_rms_norm args = {
@@ -3769,6 +3941,7 @@ static bool ggml_metal_encode_node(
3769
3941
  nth *= 2;
3770
3942
  }
3771
3943
 
3944
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3772
3945
  nth = MIN(nth, ne00/4);
3773
3946
 
3774
3947
  ggml_metal_kargs_l2_norm args = {
@@ -3841,6 +4014,7 @@ static bool ggml_metal_encode_node(
3841
4014
  nth *= 2;
3842
4015
  }
3843
4016
 
4017
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3844
4018
  nth = MIN(nth, ne00/4);
3845
4019
 
3846
4020
  ggml_metal_kargs_norm args = {
@@ -4766,6 +4940,8 @@ static bool ggml_metal_encode_node(
4766
4940
  GGML_ASSERT(nqptg % 8 == 0);
4767
4941
  GGML_ASSERT(ncpsg % 32 == 0);
4768
4942
 
4943
+ const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
4944
+
4769
4945
  // 2*(2*ncpsg + nqptg)*(nsg)
4770
4946
  // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
4771
4947
  //
@@ -4773,7 +4949,7 @@ static bool ggml_metal_encode_node(
4773
4949
  // the shared memory needed for the simdgroups to load the KV cache
4774
4950
  // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
4775
4951
  //
4776
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
4952
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
4777
4953
 
4778
4954
  int64_t nsgmax = 2;
4779
4955
 
@@ -4810,9 +4986,9 @@ static bool ggml_metal_encode_node(
4810
4986
  // and store the soft_max values and the mask
4811
4987
  //
4812
4988
  // ne00*(nsg)
4813
- // each simdgroup has a full f16 head vector in shared mem to accumulate results
4989
+ // each simdgroup has a full f32 head vector in shared mem to accumulate results
4814
4990
  //
4815
- #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
4991
+ #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
4816
4992
 
4817
4993
  int64_t nsgmax = 2;
4818
4994
  while (true) {
@@ -4925,8 +5101,39 @@ static bool ggml_metal_encode_node(
4925
5101
  default: GGML_ABORT("not implemented");
4926
5102
  }
4927
5103
 
5104
+ GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
5105
+
5106
+ // TODO: support
5107
+ //const int32_t nk00 = ne00/ggml_blck_size(dst->type);
5108
+ const int32_t nk00 = ne00;
5109
+
5110
+ int nth = 32; // SIMD width
5111
+
5112
+ while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
5113
+ nth *= 2;
5114
+ }
5115
+
5116
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
5117
+
5118
+ // when rows are small, we can batch them together in a single threadgroup
5119
+ int nrptg = 1;
5120
+
5121
+ // TODO: relax this constraint in the future
5122
+ if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
5123
+ if (nth > nk00) {
5124
+ nrptg = (nth + nk00 - 1)/nk00;
5125
+ nth = nk00;
5126
+
5127
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
5128
+ nrptg--;
5129
+ }
5130
+ }
5131
+ }
5132
+
5133
+ nth = MIN(nth, nk00);
5134
+
4928
5135
  ggml_metal_kargs_cpy args = {
4929
- /*.ne00 =*/ ne00,
5136
+ /*.ne00 =*/ nk00,
4930
5137
  /*.ne01 =*/ ne01,
4931
5138
  /*.ne02 =*/ ne02,
4932
5139
  /*.ne03 =*/ ne03,
@@ -4949,11 +5156,7 @@ static bool ggml_metal_encode_node(
4949
5156
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4950
5157
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4951
5158
 
4952
- GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
4953
- int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
4954
-
4955
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4956
-
5159
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
4957
5160
  } break;
4958
5161
  case GGML_OP_SET:
4959
5162
  {
@@ -5259,7 +5462,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
5259
5462
  }
5260
5463
 
5261
5464
  ggml_backend_metal_buffer_rset_free(ctx);
5262
- ggml_backend_metal_device_rel(buffer->buft->device->context);
5263
5465
 
5264
5466
  if (ctx->owned) {
5265
5467
  #if TARGET_OS_OSX
@@ -5368,7 +5570,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5368
5570
  }
5369
5571
 
5370
5572
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
5371
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5573
+
5574
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5575
+
5576
+ id<MTLDevice> device = ctx_dev->mtl_device;
5372
5577
 
5373
5578
  ctx->all_data = ggml_metal_host_malloc(size_aligned);
5374
5579
  ctx->all_size = size_aligned;
@@ -5391,14 +5596,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5391
5596
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
5392
5597
  GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
5393
5598
  free(ctx);
5394
- ggml_backend_metal_device_rel(ctx_dev);
5395
5599
  return NULL;
5396
5600
  }
5397
5601
 
5398
5602
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5399
5603
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5400
5604
  free(ctx);
5401
- ggml_backend_metal_device_rel(ctx_dev);
5402
5605
  return NULL;
5403
5606
  }
5404
5607
 
@@ -5409,17 +5612,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
5409
5612
 
5410
5613
  static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
5411
5614
  return 32;
5615
+
5412
5616
  GGML_UNUSED(buft);
5413
5617
  }
5414
5618
 
5415
5619
  static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
5416
- id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
5417
- const size_t max_size = device.maxBufferLength;
5418
- ggml_backend_metal_device_rel(buft->device->context);
5620
+ const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
5419
5621
 
5420
5622
  return max_size;
5421
-
5422
- GGML_UNUSED(buft);
5423
5623
  }
5424
5624
 
5425
5625
  static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5492,7 +5692,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5492
5692
  }
5493
5693
 
5494
5694
  struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
5495
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5695
+
5696
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5697
+
5698
+ id<MTLDevice> device = ctx_dev->mtl_device;
5496
5699
 
5497
5700
  // the buffer fits into the max buffer size allowed by the device
5498
5701
  if (size_aligned <= device.maxBufferLength) {
@@ -5548,7 +5751,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
5548
5751
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5549
5752
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5550
5753
  free(ctx);
5551
- ggml_backend_metal_device_rel(ctx_dev);
5552
5754
  return NULL;
5553
5755
  }
5554
5756
 
@@ -5564,10 +5766,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
5564
5766
  }
5565
5767
 
5566
5768
  static void ggml_backend_metal_free(ggml_backend_t backend) {
5567
- struct ggml_backend_metal_context * ctx = backend->context;
5568
- struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5769
+ struct ggml_backend_metal_context * ctx = backend->context;
5569
5770
 
5570
- ggml_backend_metal_device_rel(ctx_dev);
5571
5771
  ggml_metal_free(ctx);
5572
5772
 
5573
5773
  free(backend);
@@ -5707,6 +5907,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
5707
5907
 
5708
5908
  struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5709
5909
 
5910
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
5911
+
5710
5912
  return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
5711
5913
  }
5712
5914
 
@@ -5726,10 +5928,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
5726
5928
  }
5727
5929
 
5728
5930
  static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
5729
- // acq/rel just to populate ctx->name in case it hasn't been done yet
5730
5931
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5731
- ggml_backend_metal_device_acq(ctx_dev);
5732
- ggml_backend_metal_device_rel(ctx_dev);
5733
5932
 
5734
5933
  return ctx_dev->name;
5735
5934
  }
@@ -5737,12 +5936,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
5737
5936
  static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
5738
5937
  if (@available(macOS 10.12, iOS 16.0, *)) {
5739
5938
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5740
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
5939
+ id<MTLDevice> device = ctx_dev->mtl_device;
5741
5940
 
5742
5941
  *total = device.recommendedMaxWorkingSetSize;
5743
5942
  *free = *total - device.currentAllocatedSize;
5744
-
5745
- ggml_backend_metal_device_rel(ctx_dev);
5746
5943
  } else {
5747
5944
  *free = 1;
5748
5945
  *total = 1;
@@ -5820,7 +6017,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5820
6017
  }
5821
6018
 
5822
6019
  struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
5823
- id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
6020
+
6021
+ GGML_ASSERT(ctx_dev->mtl_device != nil);
6022
+
6023
+ id<MTLDevice> device = ctx_dev->mtl_device;
5824
6024
 
5825
6025
  // the buffer fits into the max buffer size allowed by the device
5826
6026
  if (size_aligned <= device.maxBufferLength) {
@@ -5876,7 +6076,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
5876
6076
  if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5877
6077
  GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5878
6078
  free(ctx);
5879
- ggml_backend_metal_device_rel(ctx_dev);
5880
6079
  return NULL;
5881
6080
  }
5882
6081
 
@@ -5890,8 +6089,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
5890
6089
  }
5891
6090
 
5892
6091
  static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
5893
- return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
5894
- buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
6092
+ return
6093
+ buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
6094
+ buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
5895
6095
 
5896
6096
  GGML_UNUSED(dev);
5897
6097
  }
@@ -5976,8 +6176,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
5976
6176
  /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
5977
6177
  };
5978
6178
 
6179
+ // called upon program exit
6180
+ static void ggml_metal_cleanup(void) {
6181
+ ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
6182
+ }
6183
+
6184
+ // TODO: make thread-safe
5979
6185
  ggml_backend_reg_t ggml_backend_metal_reg(void) {
5980
- // TODO: make this thread-safe somehow?
6186
+ ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
6187
+
6188
+ // register cleanup callback
6189
+ // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
6190
+ atexit(ggml_metal_cleanup);
6191
+
5981
6192
  {
5982
6193
  g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
5983
6194
  /* .api_version = */ GGML_BACKEND_API_VERSION,