@novastera-oss/llamarn 0.2.6 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (253)
  1. package/android/src/main/cpp/include/llama.h +141 -38
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +58 -24
  11. package/cpp/LlamaCppModel.h +3 -3
  12. package/cpp/PureCppImpl.cpp +1 -1
  13. package/cpp/PureCppImpl.h +2 -2
  14. package/cpp/build-info.cpp +2 -2
  15. package/cpp/llama.cpp/CMakeLists.txt +15 -4
  16. package/cpp/llama.cpp/Makefile +2 -2
  17. package/cpp/llama.cpp/README.md +32 -13
  18. package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
  19. package/cpp/llama.cpp/common/arg.cpp +37 -6
  20. package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
  21. package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
  22. package/cpp/llama.cpp/common/chat-parser.h +2 -0
  23. package/cpp/llama.cpp/common/chat.cpp +12 -9
  24. package/cpp/llama.cpp/common/chat.h +1 -1
  25. package/cpp/llama.cpp/common/common.cpp +53 -40
  26. package/cpp/llama.cpp/common/common.h +6 -2
  27. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  28. package/cpp/llama.cpp/common/speculative.cpp +6 -4
  29. package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
  30. package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
  31. package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
  32. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  33. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  34. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
  35. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
  36. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  38. package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
  70. package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  71. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  72. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  88. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  89. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  90. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  91. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
  93. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
  94. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
  97. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  100. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
  105. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
  112. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
  113. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  115. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  117. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
  139. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  140. package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
  141. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
  142. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
  143. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  144. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  145. package/cpp/llama.cpp/include/llama.h +141 -38
  146. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  147. package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
  148. package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
  149. package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
  150. package/cpp/llama.cpp/src/llama-arch.h +25 -1
  151. package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
  152. package/cpp/llama.cpp/src/llama-batch.h +110 -57
  153. package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
  154. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  155. package/cpp/llama.cpp/src/llama-context.cpp +360 -266
  156. package/cpp/llama.cpp/src/llama-context.h +27 -23
  157. package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
  158. package/cpp/llama.cpp/src/llama-cparams.h +1 -1
  159. package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
  160. package/cpp/llama.cpp/src/llama-graph.h +126 -58
  161. package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
  162. package/cpp/llama.cpp/src/llama-hparams.h +16 -2
  163. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
  164. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
  165. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
  166. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
  167. package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
  168. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
  169. package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
  170. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
  171. package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
  172. package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
  173. package/cpp/llama.cpp/src/llama-memory.h +73 -36
  174. package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
  175. package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
  176. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  177. package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
  178. package/cpp/llama.cpp/src/llama-model.h +26 -0
  179. package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
  180. package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
  181. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  182. package/cpp/llama.cpp/src/llama.cpp +11 -7
  183. package/cpp/llama.cpp/src/unicode.cpp +5 -0
  184. package/cpp/rn-completion.cpp +2 -2
  185. package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
  186. package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
  187. package/ios/include/chat.h +1 -1
  188. package/ios/include/common.h +6 -2
  189. package/ios/include/llama.h +141 -38
  190. package/ios/libs/llama.xcframework/Info.plist +15 -15
  191. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  192. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  193. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  194. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  195. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
  196. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  197. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  198. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  199. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  200. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  201. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  202. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  203. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  204. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  205. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  206. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
  207. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  208. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  209. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
  210. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  211. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  212. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
  213. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  219. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  220. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
  221. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  222. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  223. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
  224. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  225. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  226. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
  227. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
  228. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  231. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  232. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  233. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
  234. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  235. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  236. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
  237. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  238. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  239. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
  240. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
  241. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  242. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  243. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
  244. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  245. package/package.json +1 -2
  246. package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
  247. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  248. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
  249. package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
  250. package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
  251. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  252. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  253. /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
35
35
  -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
36
36
  };
37
37
 
38
// Returns the index of the table entry in val[0..n-1] closest to x.
// Inputs at or beyond either end of the table clamp to index 0 / n-1; when x is
// exactly halfway between two neighbouring entries the upper index is returned.
// NOTE(review): the bisection assumes val is sorted ascending and n >= 1 — confirm at call sites.
static inline int best_index_int8(int n, constant float * val, float x) {
    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[n-1]) {
        return n-1;
    }

    // Bisection maintains val[lo] <= x < val[hi] until the bracket is one step wide.
    int lo = 0;
    int hi = n-1;
    while (hi - lo > 1) {
        const int mid = (lo + hi)/2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    const float dist_below = x - val[hi-1];
    const float dist_above = val[hi] - x;
    return dist_below < dist_above ? hi-1 : hi;
}
48
+
38
49
  // NOTE: this is not dequantizing - we are simply fitting the template
39
50
  template <typename type4x4>
40
51
  void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
97
108
  }
98
109
  }
99
110
 
111
// Quantizes QK4_0 floats from src into one block_q4_0.
// The scale d is chosen so the largest-magnitude element maps to -8; every value
// is then stored as a 4-bit code in [0, 15] (offset by 8), two codes per byte:
// element j in the low nibble of qs[j], element j + QK4_0/2 in the high nibble.
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
    float abs_peak = 0.0f; // largest |src[j]| seen so far
    float peak     = 0.0f; // signed value that produced abs_peak (first one wins on ties)

    for (int j = 0; j < QK4_0; j++) {
        const float v = src[j];
        if (fabs(v) > abs_peak) {
            abs_peak = fabs(v);
            peak     = v;
        }
    }

    const float d  = peak / -8;
    const float id = d != 0.0f ? 1.0f/d : 0.0f; // all-zero block -> every code is 8

    dst.d = d;

    constexpr int half_block = QK4_0/2;
    for (int j = 0; j < half_block; ++j) {
        const float x_lo = src[j]*id;
        const float x_hi = src[half_block + j]*id;

        const uint8_t q_lo = MIN(15, (int8_t)(x_lo + 8.5f));
        const uint8_t q_hi = MIN(15, (int8_t)(x_hi + 8.5f));

        // one store of the packed byte instead of write-then-OR on device memory
        dst.qs[j] = q_lo | (q_hi << 4);
    }
}
139
+
140
// Quantizes QK4_1 floats from src into one block_q4_1 (scale d plus minimum m).
// Values are mapped linearly from [min, max] onto 4-bit codes in [0, 15];
// element j goes in the low nibble of qs[j], element j + QK4_1/2 in the high nibble.
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
    float vmin =  FLT_MAX;
    float vmax = -FLT_MAX;

    for (int j = 0; j < QK4_1; j++) {
        const float v = src[j];
        vmin = v < vmin ? v : vmin;
        vmax = v > vmax ? v : vmax;
    }

    const float d  = (vmax - vmin) / 15; // 15 == (1 << 4) - 1, the largest 4-bit code
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst.d = d;
    dst.m = vmin;

    constexpr int half_block = QK4_1/2;
    for (int j = 0; j < half_block; ++j) {
        const float x_lo = (src[j]              - vmin)*id;
        const float x_hi = (src[half_block + j] - vmin)*id;

        const uint8_t q_lo = MIN(15, (int8_t)(x_lo + 0.5f));
        const uint8_t q_hi = MIN(15, (int8_t)(x_hi + 0.5f));

        dst.qs[j] = q_lo | (q_hi << 4);
    }
}
167
+
168
// Quantizes QK5_0 floats from src into one block_q5_0.
// Like q4_0 but with 5-bit codes in [0, 31] (offset by 16, largest-magnitude
// element maps to -16). The low 4 bits of each code are packed into the qs
// nibbles; bit 4 of each code is collected in a 32-bit mask (bit j for element j,
// bit j + QK5_0/2 for element j + QK5_0/2) that is copied into qh[0..3].
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
    float abs_peak = 0.0f;
    float peak     = 0.0f;

    for (int j = 0; j < QK5_0; j++) {
        const float v = src[j];
        if (fabs(v) > abs_peak) {
            abs_peak = fabs(v);
            peak     = v;
        }
    }

    const float d  = peak / -16;
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst.d = d;

    constexpr int half_block = QK5_0/2;
    uint32_t high_bits = 0;
    for (int j = 0; j < half_block; ++j) {
        const float x_lo = src[j]*id;
        const float x_hi = src[half_block + j]*id;

        const uint8_t q_lo = MIN(31, (int8_t)(x_lo + 16.5f));
        const uint8_t q_hi = MIN(31, (int8_t)(x_hi + 16.5f));

        dst.qs[j] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);

        high_bits |= ((q_lo >> 4) & 1u) << j;
        high_bits |= ((q_hi >> 4) & 1u) << (j + half_block);
    }

    // copy the mask out byte by byte in the device's native byte order
    thread const uint8_t * high_bytes = (thread const uint8_t *) &high_bits;
    for (int b = 0; b < 4; ++b) {
        dst.qh[b] = high_bytes[b];
    }
}
204
+
205
// Quantizes QK5_1 floats from src into one block_q5_1 (scale d plus minimum m).
// Values are mapped linearly from [min, max] onto 5-bit codes in [0, 31]. The low
// 4 bits go into the qs nibbles; bit 4 of each code is collected in a 32-bit mask
// (bit j for element j, bit j + QK5_1/2 for element j + QK5_1/2) copied into qh[0..3].
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
    float vmax = src[0];
    float vmin = src[0];

    for (int j = 1; j < QK5_1; j++) {
        const float v = src[j];
        vmin = v < vmin ? v : vmin;
        vmax = v > vmax ? v : vmax;
    }

    const float d  = (vmax - vmin) / 31; // 31 == largest 5-bit code
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst.d = d;
    dst.m = vmin;

    constexpr int half_block = QK5_1/2;
    uint32_t high_bits = 0;
    for (int j = 0; j < half_block; ++j) {
        const float x_lo = (src[j]              - vmin)*id;
        const float x_hi = (src[half_block + j] - vmin)*id;

        const uint8_t q_lo = (uint8_t)(x_lo + 0.5f);
        const uint8_t q_hi = (uint8_t)(x_hi + 0.5f);

        dst.qs[j] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);

        high_bits |= ((q_lo >> 4) & 1u) << j;
        high_bits |= ((q_hi >> 4) & 1u) << (j + half_block);
    }

    // copy the mask out byte by byte in the device's native byte order
    thread const uint8_t * high_bytes = (thread const uint8_t *) &high_bits;
    for (int b = 0; b < 4; ++b) {
        dst.qh[b] = high_bytes[b];
    }
}
240
+
241
// Quantizes QK4_NL floats from src into one block_iq4_nl using the non-linear
// kvalues_iq4nl_f codebook. Values are first scaled so the largest-magnitude
// element maps to kvalues_iq4nl_f[0], then snapped to the nearest codebook entry
// (best_index_int8) and packed two 4-bit indices per byte. The stored scale is
// the weighted least-squares fit sum(w*v*x)/sum(w*v*v) with weights w = x^2,
// falling back to the initial scale when that denominator is not positive.
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
    float abs_peak = 0.0f;
    float peak     = 0.0f;

    for (int j = 0; j < QK4_NL; j++) {
        const float v = src[j];
        if (fabs(v) > abs_peak) {
            abs_peak = fabs(v);
            peak     = v;
        }
    }

    const float d  = peak / kvalues_iq4nl_f[0];
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    constexpr int half_block = QK4_NL/2;
    float sumqx = 0;
    float sumq2 = 0;
    for (int j = 0; j < half_block; ++j) {
        const float a_lo = src[j];
        const float a_hi = src[half_block + j];

        const float x_lo = a_lo*id;
        const float x_hi = a_hi*id;

        const uint8_t q_lo = best_index_int8(16, kvalues_iq4nl_f, x_lo);
        const uint8_t q_hi = best_index_int8(16, kvalues_iq4nl_f, x_hi);

        dst.qs[j] = q_lo | (q_hi << 4);

        const float v_lo = kvalues_iq4nl_f[q_lo];
        const float v_hi = kvalues_iq4nl_f[q_hi];
        const float w_lo = a_lo*a_lo;
        const float w_hi = a_hi*a_hi;
        // keep each update as a single expression so the rounding order is unchanged
        sumqx += w_lo*v_lo*a_lo + w_hi*v_hi*a_hi;
        sumq2 += w_lo*v_lo*v_lo + w_hi*v_hi*v_hi;
    }

    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
}
277
+
100
278
  template <typename type4x4>
101
279
  void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
102
280
  device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
279
457
  }
280
458
  }
281
459
 
460
// Quantizes QK8_0 floats from src into one block_q8_0: symmetric 8-bit codes
// with scale d = amax/127, each scaled value rounded with round().
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
    float abs_peak = 0.0f; // largest |src[j]|

    for (int j = 0; j < QK8_0; j++) {
        const float v = src[j];
        abs_peak = MAX(abs_peak, fabs(v));
    }

    const float d  = abs_peak / 127; // 127 == (1 << 7) - 1
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst.d = d;

    for (int j = 0; j < QK8_0; ++j) {
        const float scaled = src[j]*id;
        dst.qs[j] = round(scaled);
    }
}
479
+
282
480
  template <typename type4x4>
283
481
  void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
284
482
  const float d = xb->d;
@@ -993,31 +1191,61 @@ kernel void kernel_neg(
993
1191
  dst[tpig] = -src0[tpig];
994
1192
  }
995
1193
 
1194
// Reduces each row of src0 (ne00 floats) to one value stored in dst_row[0].
// One threadgroup handles one row (i1, i2, i3), taken from its position in the
// grid. Its threads stride over the row, partial sums are combined with simd_sum
// inside each simdgroup, and then across simdgroups through shmem_f32.
// norm == true divides the sum by ne00 (exported as "kernel_mean"); norm == false
// writes the plain sum (exported as "kernel_sum_rows").
// An earlier version of this kernel used one thread per row with a serial loop over ne00.
// NOTE(review): shmem_f32 is indexed by simd lane and simdgroup index, so the host
// presumably allocates at least one float per simd lane (32) — confirm on the host side.
template <bool norm>
kernel void kernel_sum_rows(
        constant ggml_metal_kargs_sum_rows & args,
        device const float * src0,
        device       float * dst,
        threadgroup  float * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    // row coordinates come from the threadgroup position: one threadgroup per row
    int64_t i3 = tgpig.z;
    int64_t i2 = tgpig.y;
    int64_t i1 = tgpig.x;

    // the condition is uniform across the threadgroup, so no thread skips a barrier below
    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    // clear the partial-sum slots; slots not overwritten below stay 0 so the
    // final simd_sum can read one slot per lane
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

    // each thread accumulates a strided slice of the row
    float sumf = 0;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        sumf += src_row[i0];
    }

    // first-level reduction inside the simdgroup
    sumf = simd_sum(sumf);

    // orders the zero-initialization above before the per-simdgroup writes
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // second-level reduction over the per-simdgroup partial sums
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    if (tpitg.x == 0) {
        dst_row[0] = norm ? sumf / args.ne00 : sumf;
    }
}

typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;

template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
1248
+
1021
1249
  template<typename T>
1022
1250
  kernel void kernel_soft_max(
1023
1251
  device const char * src0,
@@ -2502,6 +2730,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
2502
2730
  template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
2503
2731
  #endif
2504
2732
 
2733
// Matrix-vector product for rows that hold a single 4-wide vector: each output
// element is dot(x[0], y[0]) for one src0 row and one src1 row.
// NOTE(review): presumably selected on the host only when ne00 == 4 — confirm.
// Lane tiisg of threadgroup (x, y, z) handles src0 row r0 = 32*x + tiisg against
// up to N_MV_T_T consecutive src1 rows starting at y*N_MV_T_T, in batch im = z.
// src0 is broadcast across batches through r2/r3. Results are written as float
// to dst at im*ne0*ne1 + r1*ne0 + r0.
template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig,
        ushort tiisg) {
    const int r0 = tgpig.x*32 + tiisg;
    const int rb = tgpig.y*N_MV_T_T;
    const int im = tgpig.z;

    if (r0 >= args.ne01) {
        return;
    }

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    device const T04 * x = (device const T04 *) (src0 + offset0);

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;

    // walk this threadgroup's src1 rows, stopping early at the end of src1
    for (int r1 = rb; r1 < rb + N_MV_T_T && r1 < args.ne11; ++r1) {
        const uint64_t offset1 = r1*args.nb11 + i12*args.nb12 + i13*args.nb13;
        device const T14 * y = (device const T14 *) (src1 + offset1);

        dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
    }
}
2771
+
2772
// Kernel entry point for the 4-wide-row matrix-vector product. args_t is spelled
// out explicitly so the implementation receives the constant-address-space
// reference instead of a template-deduced by-value copy.
template<typename T04, typename T14>
kernel void kernel_mul_mv_c4(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]]) {
    kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, tgpig, tiisg);
}

using mul_mv_c4_t = decltype(kernel_mul_mv_c4<half4, half4>);

template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>;
template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
#endif
2796
+
2505
2797
  template<typename T, typename T4>
2506
2798
  kernel void kernel_mul_mv_1row(
2507
2799
  constant ggml_metal_kargs_mul_mv & args,
@@ -3328,14 +3620,12 @@ kernel void kernel_flash_attn_ext(
3328
3620
  constexpr short NW = N_SIMDWIDTH;
3329
3621
  constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
3330
3622
 
3331
- const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332
- const short T = DK + 2*TS; // shared memory size per query in (half)
3623
+ const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3624
+ const short T = 2*DK + 2*TS; // shared memory size per query in (half)
3333
3625
 
3334
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3335
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3336
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3337
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
3338
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
3626
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3627
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3628
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
3339
3629
 
3340
3630
  threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
3341
3631
  threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3644,7 @@ kernel void kernel_flash_attn_ext(
3354
3644
  if (iq1 + j < args.ne01) {
3355
3645
  sq4[j*DK4 + i] = (q4_t) q4[i];
3356
3646
  } else {
3357
- sq4[j*DK4 + i] = (q4_t) 0.0f;
3647
+ sq4[j*DK4 + i] = 0;
3358
3648
  }
3359
3649
  }
3360
3650
  }
@@ -3548,20 +3838,20 @@ kernel void kernel_flash_attn_ext(
3548
3838
 
3549
3839
  // O = diag(ms)*O
3550
3840
  {
3551
- s8x8_t mm;
3552
- simdgroup_load(mm, ss + 2*C, TS, 0, false);
3841
+ s8x8_t ms;
3842
+ simdgroup_load(ms, ss + 2*C, TS, 0, false);
3553
3843
 
3554
3844
  #pragma unroll(DV8)
3555
3845
  for (short i = 0; i < DV8; ++i) {
3556
- simdgroup_multiply(lo[i], mm, lo[i]);
3846
+ simdgroup_multiply(lo[i], ms, lo[i]);
3557
3847
  }
3558
3848
  }
3559
3849
 
3560
3850
  // O = O + (Q*K^T)*V
3561
3851
  {
3562
3852
  for (short cc = 0; cc < C/8; ++cc) {
3563
- s8x8_t ms;
3564
- simdgroup_load(ms, ss + 8*cc, TS, 0, false);
3853
+ s8x8_t vs;
3854
+ simdgroup_load(vs, ss + 8*cc, TS, 0, false);
3565
3855
 
3566
3856
  if (is_same<vd4x4_t, v4x4_t>::value) {
3567
3857
  // we can read directly from global memory
@@ -3572,7 +3862,7 @@ kernel void kernel_flash_attn_ext(
3572
3862
  v8x8_t mv;
3573
3863
  simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
3574
3864
 
3575
- simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
3865
+ simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
3576
3866
  }
3577
3867
  } else {
3578
3868
  for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3883,10 @@ kernel void kernel_flash_attn_ext(
3593
3883
  v8x8_t mv;
3594
3884
 
3595
3885
  simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
3596
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
3886
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
3597
3887
 
3598
3888
  simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
3599
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3889
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
3600
3890
  }
3601
3891
  } else {
3602
3892
  if (ii + tx < DV16) {
@@ -3611,10 +3901,10 @@ kernel void kernel_flash_attn_ext(
3611
3901
  v8x8_t mv;
3612
3902
 
3613
3903
  simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
3614
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
3904
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
3615
3905
 
3616
3906
  simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
3617
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
3907
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
3618
3908
  }
3619
3909
  }
3620
3910
  }
@@ -3624,93 +3914,89 @@ kernel void kernel_flash_attn_ext(
3624
3914
  }
3625
3915
 
3626
3916
  // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
3627
- for (short j = 0; j < Q; ++j) {
3628
- if (tiisg == 0) {
3629
- ss[j*TS + 0] = S[j];
3630
- ss[j*TS + 1] = M[j];
3631
- }
3917
+ for (short j = tiisg; j < Q; j += NW) {
3918
+ ss[j*TS + 0] = S[j];
3919
+ ss[j*TS + 1] = M[j];
3632
3920
  }
3633
3921
  }
3634
3922
 
3635
- // reduce the warps sequentially
3636
- for (ushort sg = 1; sg < nsg; ++sg) {
3637
- float S = { 0.0f };
3638
- float M = { -__FLT_MAX__/2 };
3923
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3639
3924
 
3640
- threadgroup_barrier(mem_flags::mem_threadgroup);
3925
+ threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3926
+ threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
3641
3927
 
3642
- // each simdgroup stores its output to shared memory, reusing sq
3643
- if (sgitg == sg) {
3644
- for (short i = 0; i < DV8; ++i) {
3645
- simdgroup_store(lo[i], so + i*8, DV, 0, false);
3646
- }
3928
+ // store result to shared memory in F32
3929
+ if (sgitg == 0) {
3930
+ for (short i = 0; i < DV8; ++i) {
3931
+ //simdgroup_store(lo[i], so + i*8, DV, 0, false);
3932
+ simdgroup_float8x8 t(1.0f);
3933
+ simdgroup_multiply(t, lo[i], t);
3934
+ simdgroup_store(t, so + i*8, DV, 0, false);
3647
3935
  }
3936
+ }
3648
3937
 
3649
- threadgroup_barrier(mem_flags::mem_threadgroup);
3938
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3650
3939
 
3651
- // the first simdgroup accumulates the results from the other simdgroups
3652
- if (sgitg == 0) {
3653
- for (short j = 0; j < Q; ++j) {
3654
- const float S0 = ss[j*TS + 0];
3655
- const float S1 = ss[j*TS + sg*SH + 0];
3940
+ // reduce the warps sequentially
3941
+ for (ushort sg = 1; sg < nsg; ++sg) {
3942
+ if (sgitg == sg) {
3943
+ for (short j = tiisg; j < Q; j += NW) {
3944
+ const float S0 = ss[j*TS - 1*SH + 0];
3945
+ const float S1 = ss[j*TS + 0];
3656
3946
 
3657
- const float M0 = ss[j*TS + 1];
3658
- const float M1 = ss[j*TS + sg*SH + 1];
3947
+ const float M0 = ss[j*TS - 1*SH + 1];
3948
+ const float M1 = ss[j*TS + 1];
3659
3949
 
3660
- M = max(M0, M1);
3950
+ const float M = max(M0, M1);
3661
3951
 
3662
- const float ms0 = exp(M0 - M);
3663
- const float ms1 = exp(M1 - M);
3952
+ float ms0 = exp(M0 - M);
3953
+ float ms1 = exp(M1 - M);
3664
3954
 
3665
- S = S0*ms0 + S1*ms1;
3955
+ const float S = S0*ms0 + S1*ms1;
3666
3956
 
3667
- if (tiisg == 0) {
3668
- ss[j*TS + 0] = S;
3669
- ss[j*TS + 1] = M;
3957
+ ss[j*TS + 0] = S;
3958
+ ss[j*TS + 1] = M;
3670
3959
 
3671
- ss[j*TS + 2*C + j ] = ms0;
3672
- ss[j*TS + 2*C + j + sg*SH] = ms1;
3673
- }
3960
+ ss[j*TS + 2*C + j - 1*SH] = ms0;
3961
+ ss[j*TS + 2*C + j ] = ms1;
3674
3962
  }
3675
3963
 
3964
+ //simdgroup_barrier(mem_flags::mem_threadgroup);
3965
+
3676
3966
  // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
3677
3967
  {
3678
3968
  s8x8_t ms0;
3679
3969
  s8x8_t ms1;
3680
3970
 
3681
- simdgroup_load(ms0, ss + 2*C, TS, 0, false);
3682
- simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
3971
+ simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
3972
+ simdgroup_load(ms1, ss + 2*C, TS, 0, false);
3683
3973
 
3684
3974
  #pragma unroll(DV8)
3685
3975
  for (short i = 0; i < DV8; ++i) {
3686
- o8x8_t t;
3976
+ simdgroup_float8x8 t;
3687
3977
 
3688
3978
  simdgroup_load (t, so + i*8, DV, 0, false);
3689
- simdgroup_multiply(t, ms1, t);
3979
+ simdgroup_multiply(t, ms0, t);
3690
3980
 
3691
- simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
3981
+ simdgroup_multiply_accumulate(t, ms1, lo[i], t);
3982
+ simdgroup_store(t, so + i*8, DV, 0, false);
3692
3983
  }
3693
3984
  }
3694
3985
  }
3695
- }
3696
3986
 
3697
- // store result to shared memory (reuse sq)
3698
- if (sgitg == 0) {
3699
- for (short i = 0; i < DV8; ++i) {
3700
- simdgroup_store(lo[i], so + i*8, DV, 0, false);
3701
- }
3987
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3702
3988
  }
3703
3989
 
3704
- device float4 * dst4 = (device float4 *) dst;
3990
+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
3705
3991
 
3706
3992
  // final rescale with 1/S and store to global memory
3707
- if (sgitg == 0) {
3708
- for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
3709
- const float S = ss[j*TS + 0];
3993
+ for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
3994
+ const float S = 1.0f/sf[j*TS + 0];
3710
3995
 
3711
- for (short i = tiisg; i < DV4; i += NW) {
3712
- dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3713
- }
3996
+ device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
3997
+
3998
+ for (short i = tiisg; i < DV4; i += NW) {
3999
+ dst4[i] = (float4) so4[j*DV4 + i]*S;
3714
4000
  }
3715
4001
  }
3716
4002
  }
@@ -3719,12 +4005,22 @@ kernel void kernel_flash_attn_ext(
3719
4005
  // template to be able to explore different combinations
3720
4006
  //
3721
4007
  #define FA_TYPES \
3722
- half, half4, simdgroup_half8x8, \
3723
- half, half4x4, simdgroup_half8x8, \
3724
- half, half4x4, simdgroup_half8x8, \
3725
- float, simdgroup_float8x8, \
3726
- float, simdgroup_float8x8, \
3727
- half, half4, simdgroup_half8x8
4008
+ float, float4, simdgroup_float8x8, \
4009
+ half, half4x4, simdgroup_half8x8, \
4010
+ half, half4x4, simdgroup_half8x8, \
4011
+ float, simdgroup_float8x8, \
4012
+ float, simdgroup_float8x8, \
4013
+ half, half4, simdgroup_half8x8
4014
+ //float, float4, simdgroup_float8x8
4015
+
4016
+ #define FA_TYPES_BF \
4017
+ bfloat, bfloat4, simdgroup_bfloat8x8, \
4018
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
4019
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
4020
+ float, simdgroup_float8x8, \
4021
+ float, simdgroup_float8x8, \
4022
+ half, half4, simdgroup_half8x8
4023
+ //float, float4, simdgroup_float8x8
3728
4024
 
3729
4025
  typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
3730
4026
 
@@ -3739,15 +4035,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
3739
4035
  template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
3740
4036
 
3741
4037
  #if defined(GGML_METAL_USE_BF16)
3742
- template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
3743
- template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
3744
- template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
3745
- template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
3746
- template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
3747
- template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
3748
- template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
3749
- template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
3750
- template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
4038
+ template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
4039
+ template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
4040
+ template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
4041
+ template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
4042
+ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
4043
+ template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
4044
+ template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
4045
+ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
4046
+ template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
3751
4047
  #endif
3752
4048
 
3753
4049
  template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +4097,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
3801
4097
  template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
3802
4098
 
3803
4099
  #undef FA_TYPES
4100
+ #undef FA_TYPES_BF
3804
4101
 
3805
4102
  template<
3806
4103
  typename q4_t, // query types in shared memory
@@ -3847,12 +4144,12 @@ kernel void kernel_flash_attn_ext_vec(
3847
4144
 
3848
4145
  const short T = DK + nsg*SH; // shared memory size per query in (half)
3849
4146
 
3850
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3851
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3852
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3853
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3854
- threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
3855
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
4147
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
4148
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
4149
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
4150
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
4151
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
4152
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
3856
4153
 
3857
4154
  // store the result for all queries in local memory (the O matrix from the paper)
3858
4155
  o4_t lo[DV4/NL];
@@ -4157,7 +4454,7 @@ kernel void kernel_flash_attn_ext_vec(
4157
4454
  half4, \
4158
4455
  float, \
4159
4456
  float, float4, \
4160
- half4
4457
+ float4
4161
4458
 
4162
4459
  typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
4163
4460
 
@@ -4271,11 +4568,16 @@ kernel void kernel_cpy(
4271
4568
  device const char * src0,
4272
4569
  device char * dst,
4273
4570
  uint3 tgpig[[threadgroup_position_in_grid]],
4571
+ uint tiitg[[thread_index_in_threadgroup]],
4274
4572
  ushort3 tpitg[[thread_position_in_threadgroup]],
4275
- ushort3 ntg[[threads_per_threadgroup]]) {
4573
+ ushort3 tptg[[threads_per_threadgroup]]) {
4276
4574
  const int i03 = tgpig[2];
4277
4575
  const int i02 = tgpig[1];
4278
- const int i01 = tgpig[0];
4576
+ const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
4577
+
4578
+ if (i01 >= args.ne01) {
4579
+ return;
4580
+ }
4279
4581
 
4280
4582
  const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
4281
4583
 
@@ -4286,7 +4588,7 @@ kernel void kernel_cpy(
4286
4588
 
4287
4589
  device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4288
4590
 
4289
- for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
4591
+ for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
4290
4592
  device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4291
4593
  dst_data[i00] = (T1) src[0];
4292
4594
  }
@@ -4306,6 +4608,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
4306
4608
  template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
4307
4609
  #endif
4308
4610
 
4611
+ // TODO: templetify these kernels
4309
4612
  kernel void kernel_cpy_f32_q8_0(
4310
4613
  constant ggml_metal_kargs_cpy & args,
4311
4614
  device const char * src0,
@@ -4329,23 +4632,7 @@ kernel void kernel_cpy_f32_q8_0(
4329
4632
  for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
4330
4633
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4331
4634
 
4332
- float amax = 0.0f; // absolute max
4333
-
4334
- for (int j = 0; j < QK8_0; j++) {
4335
- const float v = src[j];
4336
- amax = MAX(amax, fabs(v));
4337
- }
4338
-
4339
- const float d = amax / ((1 << 7) - 1);
4340
- const float id = d ? 1.0f/d : 0.0f;
4341
-
4342
- dst_data[i00/QK8_0].d = d;
4343
-
4344
- for (int j = 0; j < QK8_0; ++j) {
4345
- const float x0 = src[j]*id;
4346
-
4347
- dst_data[i00/QK8_0].qs[j] = round(x0);
4348
- }
4635
+ quantize_q8_0(src, dst_data[i00/QK8_0]);
4349
4636
  }
4350
4637
  }
4351
4638
 
@@ -4372,32 +4659,7 @@ kernel void kernel_cpy_f32_q4_0(
4372
4659
  for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
4373
4660
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4374
4661
 
4375
- float amax = 0.0f; // absolute max
4376
- float max = 0.0f;
4377
-
4378
- for (int j = 0; j < QK4_0; j++) {
4379
- const float v = src[j];
4380
- if (amax < fabs(v)) {
4381
- amax = fabs(v);
4382
- max = v;
4383
- }
4384
- }
4385
-
4386
- const float d = max / -8;
4387
- const float id = d ? 1.0f/d : 0.0f;
4388
-
4389
- dst_data[i00/QK4_0].d = d;
4390
-
4391
- for (int j = 0; j < QK4_0/2; ++j) {
4392
- const float x0 = src[0 + j]*id;
4393
- const float x1 = src[QK4_0/2 + j]*id;
4394
-
4395
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
4396
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
4397
-
4398
- dst_data[i00/QK4_0].qs[j] = xi0;
4399
- dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
4400
- }
4662
+ quantize_q4_0(src, dst_data[i00/QK4_0]);
4401
4663
  }
4402
4664
  }
4403
4665
 
@@ -4424,31 +4686,7 @@ kernel void kernel_cpy_f32_q4_1(
4424
4686
  for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
4425
4687
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4426
4688
 
4427
- float min = FLT_MAX;
4428
- float max = -FLT_MAX;
4429
-
4430
- for (int j = 0; j < QK4_1; j++) {
4431
- const float v = src[j];
4432
- if (min > v) min = v;
4433
- if (max < v) max = v;
4434
- }
4435
-
4436
- const float d = (max - min) / ((1 << 4) - 1);
4437
- const float id = d ? 1.0f/d : 0.0f;
4438
-
4439
- dst_data[i00/QK4_1].d = d;
4440
- dst_data[i00/QK4_1].m = min;
4441
-
4442
- for (int j = 0; j < QK4_1/2; ++j) {
4443
- const float x0 = (src[0 + j] - min)*id;
4444
- const float x1 = (src[QK4_1/2 + j] - min)*id;
4445
-
4446
- const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
4447
- const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
4448
-
4449
- dst_data[i00/QK4_1].qs[j] = xi0;
4450
- dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
4451
- }
4689
+ quantize_q4_1(src, dst_data[i00/QK4_1]);
4452
4690
  }
4453
4691
  }
4454
4692
 
@@ -4475,38 +4713,7 @@ kernel void kernel_cpy_f32_q5_0(
4475
4713
  for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
4476
4714
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4477
4715
 
4478
- float amax = 0.0f; // absolute max
4479
- float max = 0.0f;
4480
-
4481
- for (int j = 0; j < QK5_0; j++) {
4482
- const float v = src[j];
4483
- if (amax < fabs(v)) {
4484
- amax = fabs(v);
4485
- max = v;
4486
- }
4487
- }
4488
-
4489
- const float d = max / -16;
4490
- const float id = d ? 1.0f/d : 0.0f;
4491
-
4492
- dst_data[i00/QK5_0].d = d;
4493
-
4494
- uint32_t qh = 0;
4495
- for (int j = 0; j < QK5_0/2; ++j) {
4496
- const float x0 = src[0 + j]*id;
4497
- const float x1 = src[QK5_0/2 + j]*id;
4498
-
4499
- const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
4500
- const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
4501
-
4502
- dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4503
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4504
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
4505
- }
4506
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4507
- for (int j = 0; j < 4; ++j) {
4508
- dst_data[i00/QK5_0].qh[j] = qh8[j];
4509
- }
4716
+ quantize_q5_0(src, dst_data[i00/QK5_0]);
4510
4717
  }
4511
4718
  }
4512
4719
 
@@ -4533,51 +4740,10 @@ kernel void kernel_cpy_f32_q5_1(
4533
4740
  for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
4534
4741
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4535
4742
 
4536
- float max = src[0];
4537
- float min = src[0];
4538
-
4539
- for (int j = 1; j < QK5_1; j++) {
4540
- const float v = src[j];
4541
- min = v < min ? v : min;
4542
- max = v > max ? v : max;
4543
- }
4544
-
4545
- const float d = (max - min) / 31;
4546
- const float id = d ? 1.0f/d : 0.0f;
4547
-
4548
- dst_data[i00/QK5_1].d = d;
4549
- dst_data[i00/QK5_1].m = min;
4550
-
4551
- uint32_t qh = 0;
4552
- for (int j = 0; j < QK5_1/2; ++j) {
4553
- const float x0 = (src[0 + j] - min)*id;
4554
- const float x1 = (src[QK5_1/2 + j] - min)*id;
4555
-
4556
- const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
4557
- const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
4558
-
4559
- dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
4560
- qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
4561
- qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
4562
- }
4563
- thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
4564
- for (int j = 0; j < 4; ++j) {
4565
- dst_data[i00/QK5_1].qh[j] = qh8[j];
4566
- }
4743
+ quantize_q5_1(src, dst_data[i00/QK5_1]);
4567
4744
  }
4568
4745
  }
4569
4746
 
4570
- static inline int best_index_int8(int n, constant float * val, float x) {
4571
- if (x <= val[0]) return 0;
4572
- if (x >= val[n-1]) return n-1;
4573
- int ml = 0, mu = n-1;
4574
- while (mu-ml > 1) {
4575
- int mav = (ml+mu)/2;
4576
- if (x < val[mav]) mu = mav; else ml = mav;
4577
- }
4578
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
4579
- }
4580
-
4581
4747
  kernel void kernel_cpy_f32_iq4_nl(
4582
4748
  constant ggml_metal_kargs_cpy & args,
4583
4749
  device const char * src0,
@@ -4601,40 +4767,7 @@ kernel void kernel_cpy_f32_iq4_nl(
4601
4767
  for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
4602
4768
  device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4603
4769
 
4604
- float amax = 0.0f; // absolute max
4605
- float max = 0.0f;
4606
-
4607
- for (int j = 0; j < QK4_NL; j++) {
4608
- const float v = src[j];
4609
- if (amax < fabs(v)) {
4610
- amax = fabs(v);
4611
- max = v;
4612
- }
4613
- }
4614
-
4615
- const float d = max / kvalues_iq4nl_f[0];
4616
- const float id = d ? 1.0f/d : 0.0f;
4617
-
4618
- float sumqx = 0, sumq2 = 0;
4619
- for (int j = 0; j < QK4_NL/2; ++j) {
4620
- const float x0 = src[0 + j]*id;
4621
- const float x1 = src[QK4_NL/2 + j]*id;
4622
-
4623
- const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
4624
- const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
4625
-
4626
- dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
4627
-
4628
- const float v0 = kvalues_iq4nl_f[xi0];
4629
- const float v1 = kvalues_iq4nl_f[xi1];
4630
- const float w0 = src[0 + j]*src[0 + j];
4631
- const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
4632
- sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
4633
- sumq2 += w0*v0*v0 + w1*v1*v1;
4634
-
4635
- }
4636
-
4637
- dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
4770
+ quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
4638
4771
  }
4639
4772
  }
4640
4773
 
@@ -6315,10 +6448,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
6315
6448
 
6316
6449
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
6317
6450
  kernel void kernel_get_rows_q(
6451
+ constant ggml_metal_kargs_get_rows & args,
6318
6452
  device const void * src0,
6319
6453
  device const void * src1,
6320
6454
  device float * dst,
6321
- constant ggml_metal_kargs_get_rows & args,
6322
6455
  uint3 tgpig[[threadgroup_position_in_grid]],
6323
6456
  uint tiitg[[thread_index_in_threadgroup]],
6324
6457
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6338,10 +6471,10 @@ kernel void kernel_get_rows_q(
6338
6471
 
6339
6472
  template<typename T>
6340
6473
  kernel void kernel_get_rows_f(
6474
+ constant ggml_metal_kargs_get_rows & args,
6341
6475
  device const void * src0,
6342
6476
  device const void * src1,
6343
6477
  device float * dst,
6344
- constant ggml_metal_kargs_get_rows & args,
6345
6478
  uint3 tgpig[[threadgroup_position_in_grid]],
6346
6479
  uint tiitg[[thread_index_in_threadgroup]],
6347
6480
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6359,10 +6492,10 @@ kernel void kernel_get_rows_f(
6359
6492
  }
6360
6493
 
6361
6494
  kernel void kernel_get_rows_i32(
6495
+ constant ggml_metal_kargs_get_rows & args,
6362
6496
  device const void * src0,
6363
6497
  device const void * src1,
6364
6498
  device int32_t * dst,
6365
- constant ggml_metal_kargs_get_rows & args,
6366
6499
  uint3 tgpig[[threadgroup_position_in_grid]],
6367
6500
  uint tiitg[[thread_index_in_threadgroup]],
6368
6501
  uint3 tptg [[threads_per_threadgroup]]) {
@@ -6379,6 +6512,67 @@ kernel void kernel_get_rows_i32(
6379
6512
  }
6380
6513
  }
6381
6514
 
6515
+ template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
6516
+ kernel void kernel_set_rows_q32(
6517
+ constant ggml_metal_kargs_set_rows & args,
6518
+ device const void * src0,
6519
+ device const void * src1,
6520
+ device float * dst,
6521
+ uint3 tgpig[[threadgroup_position_in_grid]],
6522
+ uint tiitg[[thread_index_in_threadgroup]],
6523
+ uint3 tptg [[threads_per_threadgroup]]) {
6524
+ const int32_t i03 = tgpig.z;
6525
+ const int32_t i02 = tgpig.y;
6526
+
6527
+ const int32_t i12 = i03%args.ne12;
6528
+ const int32_t i11 = i02%args.ne11;
6529
+
6530
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
6531
+ if (i01 >= args.ne01) {
6532
+ return;
6533
+ }
6534
+
6535
+ const int32_t i10 = i01;
6536
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6537
+
6538
+ device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6539
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6540
+
6541
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
6542
+ quantize_func(src_row + 32*ind, dst_row[ind]);
6543
+ }
6544
+ }
6545
+
6546
+ template<typename T>
6547
+ kernel void kernel_set_rows_f(
6548
+ constant ggml_metal_kargs_set_rows & args,
6549
+ device const void * src0,
6550
+ device const void * src1,
6551
+ device float * dst,
6552
+ uint3 tgpig[[threadgroup_position_in_grid]],
6553
+ uint tiitg[[thread_index_in_threadgroup]],
6554
+ uint3 tptg [[threads_per_threadgroup]]) {
6555
+ const int32_t i03 = tgpig.z;
6556
+ const int32_t i02 = tgpig.y;
6557
+
6558
+ const int32_t i12 = i03%args.ne12;
6559
+ const int32_t i11 = i02%args.ne11;
6560
+
6561
+ const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
6562
+ if (i01 >= args.ne01) {
6563
+ return;
6564
+ }
6565
+
6566
+ const int32_t i10 = i01;
6567
+ const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
6568
+
6569
+ device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
6570
+ const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
6571
+
6572
+ for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
6573
+ dst_row[ind] = (T) src_row[ind];
6574
+ }
6575
+ }
6382
6576
 
6383
6577
  #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
6384
6578
  #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6802,6 +6996,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
6802
6996
  template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
6803
6997
  template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
6804
6998
 
6999
+ //
7000
+ // set rows
7001
+ //
7002
+
7003
+ typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
7004
+
7005
+ template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
7006
+ template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
7007
+ #if defined(GGML_METAL_USE_BF16)
7008
+ template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
7009
+ #endif
7010
+
7011
+ typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
7012
+
7013
+ template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
7014
+ template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
7015
+ template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
7016
+ template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
7017
+ template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
7018
+ template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
7019
+
6805
7020
  //
6806
7021
  // matrix-matrix multiplication
6807
7022
  //