@novastera-oss/llamarn 0.2.1 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. package/README.md +80 -14
  2. package/RNLlamaCpp.podspec +10 -3
  3. package/android/CMakeLists.txt +8 -0
  4. package/android/src/main/cpp/include/llama.h +62 -125
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  8. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  10. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  11. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  12. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  13. package/cpp/PureCppImpl.cpp +9 -27
  14. package/cpp/SystemUtils.h +2 -2
  15. package/cpp/build-info.cpp +2 -2
  16. package/cpp/llama.cpp/README.md +11 -3
  17. package/cpp/llama.cpp/build-xcframework.sh +1 -0
  18. package/cpp/llama.cpp/common/CMakeLists.txt +8 -2
  19. package/cpp/llama.cpp/common/arg.cpp +153 -113
  20. package/cpp/llama.cpp/common/chat-parser.cpp +379 -0
  21. package/cpp/llama.cpp/common/chat-parser.h +117 -0
  22. package/cpp/llama.cpp/common/chat.cpp +847 -699
  23. package/cpp/llama.cpp/common/chat.h +73 -6
  24. package/cpp/llama.cpp/common/common.cpp +50 -82
  25. package/cpp/llama.cpp/common/common.h +21 -17
  26. package/cpp/llama.cpp/common/json-partial.cpp +255 -0
  27. package/cpp/llama.cpp/common/json-partial.h +37 -0
  28. package/cpp/llama.cpp/common/minja/chat-template.hpp +9 -5
  29. package/cpp/llama.cpp/common/minja/minja.hpp +69 -36
  30. package/cpp/llama.cpp/common/regex-partial.cpp +204 -0
  31. package/cpp/llama.cpp/common/regex-partial.h +56 -0
  32. package/cpp/llama.cpp/common/sampling.cpp +7 -8
  33. package/cpp/llama.cpp/convert_hf_to_gguf.py +453 -118
  34. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +120 -68
  35. package/cpp/llama.cpp/ggml/CMakeLists.txt +2 -1
  36. package/cpp/llama.cpp/ggml/cmake/common.cmake +25 -0
  37. package/cpp/llama.cpp/ggml/include/ggml-opt.h +49 -28
  38. package/cpp/llama.cpp/ggml/include/ggml.h +26 -7
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +16 -10
  40. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +4 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  42. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +604 -0
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +42 -0
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +54 -2
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +50 -51
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -2
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -9
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +779 -19
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +22 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +322 -100
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +117 -1
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +85 -16
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +220 -49
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/acc.cu +40 -26
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +1 -1
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +11 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +15 -7
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +266 -64
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +49 -4
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +48 -4
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +2 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +5 -1
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +2 -0
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +7 -6
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/sum.cu +1 -1
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +10 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-impl.h +1 -1
  73. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +4 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +99 -17
  75. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +200 -2
  76. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +8 -2
  77. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cu +112 -0
  78. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +12 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +6 -0
  80. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +972 -178
  81. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  82. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  84. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-opt.cpp +373 -190
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  89. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -10
  90. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +101 -5
  91. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +31 -33
  92. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +1 -0
  93. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +29 -2
  94. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +4 -5
  95. package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  96. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +9 -1
  97. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +84 -72
  98. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  100. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -3
  101. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +324 -129
  102. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +1 -0
  103. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +31 -2
  104. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +95 -68
  105. package/cpp/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +1 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +22 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -2
  108. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -4
  109. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +2 -3
  110. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +69 -43
  111. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +2 -14
  112. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -91
  113. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -181
  114. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +17 -0
  115. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  116. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +6 -152
  117. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  118. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  119. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +2 -118
  120. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +1 -1
  121. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +12 -1
  122. package/cpp/llama.cpp/ggml/src/ggml.c +107 -36
  123. package/cpp/llama.cpp/ggml/src/gguf.cpp +33 -33
  124. package/cpp/llama.cpp/gguf-py/gguf/constants.py +100 -15
  125. package/cpp/llama.cpp/gguf-py/gguf/gguf_reader.py +1 -1
  126. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +44 -12
  127. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_editor_gui.py +21 -10
  128. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_new_metadata.py +5 -2
  129. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +128 -31
  130. package/cpp/llama.cpp/gguf-py/gguf/utility.py +1 -1
  131. package/cpp/llama.cpp/gguf-py/pyproject.toml +1 -1
  132. package/cpp/llama.cpp/include/llama.h +62 -125
  133. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +1 -1
  134. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +1 -1
  135. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +1 -1
  136. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +1 -1
  137. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +1 -1
  138. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +1 -1
  139. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +1 -1
  140. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +1 -1
  141. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +1 -1
  142. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +1 -1
  143. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +1 -1
  144. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +1 -1
  145. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +1 -1
  146. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +1 -1
  147. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +1 -1
  148. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +1 -1
  149. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +1 -1
  150. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +1 -1
  151. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  152. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +1 -1
  153. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +1 -1
  154. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +1 -1
  155. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +1 -1
  156. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +1 -1
  157. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +1 -1
  158. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +1 -1
  159. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +1 -1
  160. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +62 -0
  161. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +85 -0
  162. package/cpp/llama.cpp/models/templates/README.md +2 -0
  163. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +5 -1
  164. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +5 -1
  165. package/cpp/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  166. package/cpp/llama.cpp/requirements/requirements-gguf_editor_gui.txt +1 -1
  167. package/cpp/llama.cpp/src/CMakeLists.txt +2 -0
  168. package/cpp/llama.cpp/src/llama-arch.cpp +6 -0
  169. package/cpp/llama.cpp/src/llama-arch.h +2 -0
  170. package/cpp/llama.cpp/src/llama-batch.cpp +3 -1
  171. package/cpp/llama.cpp/src/llama-context.cpp +340 -123
  172. package/cpp/llama.cpp/src/llama-context.h +30 -0
  173. package/cpp/llama.cpp/src/llama-cparams.cpp +4 -0
  174. package/cpp/llama.cpp/src/llama-cparams.h +2 -0
  175. package/cpp/llama.cpp/src/llama-grammar.cpp +12 -2
  176. package/cpp/llama.cpp/src/llama-graph.cpp +157 -247
  177. package/cpp/llama.cpp/src/llama-graph.h +52 -7
  178. package/cpp/llama.cpp/src/llama-hparams.cpp +17 -1
  179. package/cpp/llama.cpp/src/llama-hparams.h +37 -5
  180. package/cpp/llama.cpp/src/llama-kv-cache.cpp +742 -481
  181. package/cpp/llama.cpp/src/llama-kv-cache.h +196 -99
  182. package/cpp/llama.cpp/src/llama-kv-cells.h +379 -0
  183. package/cpp/llama.cpp/src/llama-memory.h +4 -3
  184. package/cpp/llama.cpp/src/llama-model-loader.cpp +22 -17
  185. package/cpp/llama.cpp/src/llama-model-saver.cpp +281 -0
  186. package/cpp/llama.cpp/src/llama-model-saver.h +37 -0
  187. package/cpp/llama.cpp/src/llama-model.cpp +529 -172
  188. package/cpp/llama.cpp/src/llama-model.h +6 -1
  189. package/cpp/llama.cpp/src/llama-quant.cpp +15 -13
  190. package/cpp/llama.cpp/src/llama-sampling.cpp +2 -2
  191. package/cpp/llama.cpp/src/llama-vocab.cpp +35 -8
  192. package/cpp/llama.cpp/src/llama-vocab.h +6 -0
  193. package/cpp/llama.cpp/src/llama.cpp +14 -0
  194. package/cpp/rn-completion.cpp +60 -5
  195. package/ios/include/chat.h +73 -6
  196. package/ios/include/common/minja/chat-template.hpp +9 -5
  197. package/ios/include/common/minja/minja.hpp +69 -36
  198. package/ios/include/common.h +21 -17
  199. package/ios/include/llama.h +62 -125
  200. package/ios/libs/llama.xcframework/Info.plist +19 -19
  201. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  202. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4617 -4487
  203. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  204. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +26 -7
  205. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +62 -125
  206. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  207. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  208. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  209. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3557 -3435
  210. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  211. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  212. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  213. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  214. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  215. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4638 -4508
  216. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3559 -3437
  217. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +237 -0
  218. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +26 -7
  219. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +62 -125
  220. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +237 -0
  221. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +26 -7
  222. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +62 -125
  223. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  224. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +237 -0
  225. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +26 -7
  226. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +62 -125
  227. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  228. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  229. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  230. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4616 -4487
  231. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  232. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +26 -7
  233. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +62 -125
  234. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  235. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  236. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4637 -4508
  237. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3556 -3435
  238. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  239. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  240. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  241. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  242. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  243. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4653 -4523
  244. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +237 -0
  245. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +26 -7
  246. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +62 -125
  247. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4674 -4544
  250. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3587 -3465
  251. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +237 -0
  252. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +26 -7
  253. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +62 -125
  254. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  255. package/package.json +1 -1
  256. package/cpp/llama.cpp/common/stb_image.h +0 -7988
  257. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +0 -112
  258. package/cpp/llama.cpp/models/ggml-vocab-chameleon.gguf.out +0 -46
  259. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +0 -112
  260. package/cpp/llama.cpp/models/ggml-vocab-deepseek-r1-qwen.gguf.out +0 -46
  261. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +0 -112
  262. package/cpp/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +0 -46
  263. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.inp +0 -112
  264. package/cpp/llama.cpp/models/ggml-vocab-llama4.gguf.out +0 -46
  265. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +0 -112
  266. package/cpp/llama.cpp/models/ggml-vocab-pixtral.gguf.out +0 -46
  267. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +0 -112
  268. package/cpp/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +0 -46
@@ -0,0 +1,360 @@
1
+ #version 450
2
+
3
+ #extension GL_EXT_control_flow_attributes : enable
4
+ #extension GL_EXT_shader_16bit_storage : require
5
+
6
+ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
7
+ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
8
+
9
+ #extension GL_KHR_shader_subgroup_basic : enable
10
+ #extension GL_KHR_memory_scope_semantics : enable
11
+ #extension GL_KHR_cooperative_matrix : enable
12
+
13
+ #include "types.comp"
14
+ #include "flash_attn_base.comp"
15
+
16
+ const uint32_t D_per_thread = D / D_split;
17
+ const uint32_t row_split = 4;
18
+ const uint32_t rows_per_thread = Br / row_split;
19
+ const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
20
+ const uint32_t cols_per_thread = Bc / cols_per_iter;
21
+
22
+
23
+ layout (binding = 0) readonly buffer Q {float data_q[];};
24
+ layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
25
+ layout (binding = 1) readonly buffer K {float16_t data_k[];};
26
+ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
27
+ layout (binding = 2) readonly buffer V {float16_t data_v[];};
28
+ layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
29
+ layout (binding = 3) readonly buffer M {float16_t data_m[];};
30
+
31
+ // Store the output when doing grouped query attention.
32
+ // Rows index by Q's dimension 2, and the first N rows are valid.
33
+ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
34
+ {
35
+ uint32_t offset = (iq2 + r) * D + c;
36
+ data_o[o_offset + offset] = D_TYPE(elem);
37
+ return elem;
38
+ }
39
+
40
+ // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
41
+ const uint32_t MatBr = 16;
42
+ const uint32_t MatBc = 16;
43
+
44
+ shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
45
+ shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
46
+
47
+ const uint32_t qstride = D / 4 + 2; // in units of f16vec4
48
+ shared f16vec4 Qf[Br * qstride];
49
+
50
+ // Avoid padding for D==256 to make it fit in 48KB shmem.
51
+ const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
52
+ shared ACC_TYPE sfsh[Bc * sfshstride];
53
+
54
+ const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
55
+ shared f16vec4 ksh[Bc * kshstride];
56
+
57
+ shared float slope[Br];
58
+
59
+ void main() {
60
+ #ifdef NEEDS_INIT_IQ_SHMEM
61
+ init_iq_shmem(gl_WorkGroupSize);
62
+ #endif
63
+
64
+ init_indices();
65
+
66
+ const uint32_t tid = gl_LocalInvocationIndex;
67
+
68
+ const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
69
+ const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
70
+ const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
71
+ const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
72
+
73
+ #define tile_row(r) (row_tid * rows_per_thread + (r))
74
+
75
+ uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
76
+
77
+ [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
78
+ uint32_t d = (idx + tid) % (D / 4);
79
+ uint32_t r = (idx + tid) / (D / 4);
80
+ if (r < Br && d < D / 4 &&
81
+ i * Br + r < N) {
82
+ Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
83
+ }
84
+ }
85
+ barrier();
86
+
87
+ ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
88
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
89
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
90
+ Of[r][d] = ACC_TYPEV4(0.0);
91
+ }
92
+ }
93
+
94
+ float Lf[rows_per_thread], Mf[rows_per_thread];
95
+
96
+ // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
97
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
98
+
99
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
100
+ Lf[r] = 0;
101
+ Mf[r] = NEG_FLT_MAX_OVER_2;
102
+ }
103
+
104
+ // ALiBi
105
+ if (p.max_bias > 0.0f) {
106
+ if (tid < Br) {
107
+ uint r = tid;
108
+ slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
109
+ }
110
+ barrier();
111
+ } else {
112
+ if (tid < Br) {
113
+ uint r = tid;
114
+ slope[r] = 1.0;
115
+ }
116
+ barrier();
117
+ }
118
+
119
+ #if BLOCK_SIZE > 1
120
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
121
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
122
+ #else
123
+ uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
124
+ uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
125
+ #endif
126
+
127
+ [[dont_unroll]]
128
+ for (uint32_t j = start_j; j < end_j; ++j) {
129
+
130
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
131
+ uint32_t d = (idx + tid) % (D / 4);
132
+ uint32_t c = (idx + tid) / (D / 4);
133
+ if (c < Bc && d < D / 4) {
134
+ #if BLOCK_SIZE > 1
135
+ uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
136
+ uint ib = coord / BLOCK_SIZE;
137
+ uint iqs = (coord % BLOCK_SIZE);
138
+ f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
139
+ #else
140
+ f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
141
+ #endif
142
+
143
+ ksh[c * kshstride + d] = K_Tf;
144
+ }
145
+ }
146
+ barrier();
147
+
148
+ // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
149
+ // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
150
+ // This is written transposed in order to allow for N being 8 if implementations need it
151
+ coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
152
+ coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
153
+ coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
154
+
155
+ for (uint32_t d = 0; d < D / 16; ++d) {
156
+ coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
157
+
158
+ uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
159
+ coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
160
+
161
+ SfMat = coopMatMulAdd(KMat, QMat, SfMat);
162
+ }
163
+
164
+ uint coord = gl_SubgroupID * MatBc * sfshstride;
165
+ coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
166
+ barrier();
167
+
168
+ if (p.logit_softcap != 0.0f) {
169
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
170
+ uint32_t c = (idx + tid) / Br;
171
+ uint32_t r = (idx + tid) % Br;
172
+ if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
173
+ sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
174
+ }
175
+ }
176
+ barrier();
177
+ }
178
+
179
+ if (p.mask != 0) {
180
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
181
+ uint32_t c = (idx + tid) % Bc;
182
+ uint32_t r = (idx + tid) / Bc;
183
+ if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
184
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
185
+ }
186
+ }
187
+ barrier();
188
+ }
189
+
190
+ float eMf[rows_per_thread];
191
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
192
+ float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
193
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
194
+ rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
195
+ }
196
+ float Moldf = Mf[r];
197
+
198
+ // M = max(rowmax, Mold)
199
+ // P = e^(S - M)
200
+ // eM = e^(Mold - M)
201
+ Mf[r] = max(rowmaxf, Moldf);
202
+ eMf[r] = exp(Moldf - Mf[r]);
203
+ }
204
+
205
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
206
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
207
+ Of[r][d] = float16_t(eMf[r]) * Of[r][d];
208
+ }
209
+ }
210
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
211
+ Lf[r] = eMf[r]*Lf[r];
212
+ }
213
+
214
+ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
215
+ float Pf[rows_per_thread];
216
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
217
+ Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
218
+ Lf[r] += Pf[r];
219
+ }
220
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
221
+ #if BLOCK_SIZE > 1
222
+ uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
223
+ uint ib = coord / BLOCK_SIZE;
224
+ uint iqs = (coord % BLOCK_SIZE);
225
+ vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
226
+ #else
227
+ vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
228
+ #endif
229
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
230
+ Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
231
+ }
232
+ }
233
+ }
234
+
235
+ barrier();
236
+ }
237
+
238
+ // reduce across threads
239
+
240
+ float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
241
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
242
+ FLOAT_TYPE M = Mf[r];
243
+ tmpsh[tid] = M;
244
+ // Compute max across the row
245
+ barrier();
246
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
247
+ M = max(M, tmpsh[tid ^ s]);
248
+ barrier();
249
+ tmpsh[tid] = M;
250
+ barrier();
251
+ }
252
+ rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
253
+ barrier();
254
+ }
255
+
256
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
257
+ Moldf[r] = Mf[r];
258
+
259
+ // M = max(rowmax, Mold)
260
+ // eM = e^(Mold - M)
261
+ Mf[r] = max(rowmaxf[r], Moldf[r]);
262
+ eMf[r] = exp(Moldf[r] - Mf[r]);
263
+
264
+ Lf[r] = eMf[r]*Lf[r];
265
+ }
266
+
267
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
268
+ FLOAT_TYPE L = Lf[r];
269
+ tmpsh[tid] = L;
270
+ // Compute sum across the row
271
+ barrier();
272
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
273
+ L += tmpsh[tid ^ s];
274
+ barrier();
275
+ tmpsh[tid] = L;
276
+ barrier();
277
+ }
278
+ Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
279
+ barrier();
280
+ }
281
+
282
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
283
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
284
+
285
+ Of[r][d] = float16_t(eMf[r]) * Of[r][d];
286
+ tmpshv4[tid] = Of[r][d];
287
+
288
+ barrier();
289
+ [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
290
+ Of[r][d] += tmpshv4[tid ^ s];
291
+ barrier();
292
+ tmpshv4[tid] = Of[r][d];
293
+ barrier();
294
+ }
295
+ Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
296
+ barrier();
297
+ }
298
+ }
299
+
300
+ // If there is split_k, then the split_k resolve shader does the final
301
+ // division by L. Store the intermediate O value and per-row m and L values.
302
+ if (p.k_num > 1) {
303
+ uint32_t o_offset = D * p.ne1 * split_k_index;
304
+
305
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
306
+ if (tile_row(r) < N) {
307
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
308
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
309
+ perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
310
+ }
311
+ }
312
+ }
313
+ }
314
+
315
+ o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
316
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
317
+ if (tile_row(r) < N) {
318
+ perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
319
+ perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
320
+ }
321
+ }
322
+
323
+ return;
324
+ }
325
+
326
+ float Lfrcp[rows_per_thread];
327
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
328
+ Lfrcp[r] = 1.0 / Lf[r];
329
+ }
330
+
331
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
332
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
333
+ Of[r][d] *= float16_t(Lfrcp[r]);
334
+ }
335
+ }
336
+
337
+ uint32_t o_offset = iq3*p.ne2*p.ne1;
338
+
339
+ if (p.gqa_ratio > 1) {
340
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
341
+ if (tile_row(r) < N) {
342
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
343
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
344
+ perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
345
+ }
346
+ }
347
+ }
348
+ }
349
+ } else {
350
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
351
+ if (i * Br + tile_row(r) < N) {
352
+ [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
353
+ [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
354
+ data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
355
+ }
356
+ }
357
+ }
358
+ }
359
+ }
360
+ }
@@ -18,62 +18,12 @@
18
18
 
19
19
  #include "types.comp"
20
20
  #include "dequant_funcs_cm2.comp"
21
-
22
- layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
23
-
24
- layout (constant_id = 1) const uint32_t Br = 32;
25
- layout (constant_id = 2) const uint32_t Bc = 32;
26
- layout (constant_id = 3) const uint32_t D = 32;
27
- layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
28
-
29
- layout (push_constant) uniform parameter {
30
- uint32_t N;
31
- uint32_t KV;
32
-
33
- uint32_t ne1;
34
- uint32_t ne2;
35
- uint32_t ne3;
36
-
37
- uint32_t neq2;
38
- uint32_t neq3;
39
- uint32_t nek2;
40
- uint32_t nek3;
41
- uint32_t nev2;
42
- uint32_t nev3;
43
- uint32_t nem1;
44
-
45
- uint32_t nb01;
46
- uint32_t nb02;
47
- uint32_t nb03;
48
- uint32_t nb11;
49
- uint32_t nb12;
50
- uint32_t nb13;
51
- uint32_t nb21;
52
- uint32_t nb22;
53
- uint32_t nb23;
54
- uint32_t nb31;
55
-
56
- float scale;
57
- float max_bias;
58
- float logit_softcap;
59
-
60
- uint32_t mask;
61
- uint32_t n_head_log2;
62
- float m0;
63
- float m1;
64
-
65
- uint32_t gqa_ratio;
66
- uint32_t split_kv;
67
- uint32_t k_num;
68
- } p;
21
+ #include "flash_attn_base.comp"
69
22
 
70
23
  layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
71
24
  layout (binding = 1) readonly buffer K {uint8_t data_k[];};
72
25
  layout (binding = 2) readonly buffer V {uint8_t data_v[];};
73
26
  layout (binding = 3) readonly buffer M {uint8_t data_m[];};
74
- layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
75
-
76
- #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
77
27
 
78
28
  ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
79
29
  return max(x, y);
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
118
68
  return elem;
119
69
  }
120
70
 
121
- // Store column zero. This is used to save per-row m and L values for split_k.
122
- ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
123
- {
124
- if (r < N && c == 0) {
125
- uint32_t offset = iq2 + r;
126
- data_o[o_offset + offset] = D_TYPE(elem);
127
- }
128
- return elem;
129
- }
130
-
131
- // Load the slope matrix, indexed by Q's dimension 2.
132
- ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
133
- {
134
- const uint32_t h = iq2 + (r % p.gqa_ratio);
135
-
136
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
137
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
138
-
139
- return ACC_TYPE(pow(base, ACC_TYPE(exph)));
140
- }
141
-
142
71
  void main() {
143
72
  #ifdef NEEDS_INIT_IQ_SHMEM
144
73
  init_iq_shmem(gl_WorkGroupSize);
145
74
  #endif
146
75
 
147
- const uint32_t N = p.N;
148
- const uint32_t KV = p.KV;
149
-
150
- uint32_t i = gl_WorkGroupID.x;
151
- uint32_t split_k_index = 0;
152
-
153
- if (p.k_num > 1) {
154
- i = 0;
155
- split_k_index = gl_WorkGroupID.x;
156
- }
157
-
158
- const uint32_t Tr = CEIL_DIV(N, Br);
159
-
160
- const uint32_t start_j = split_k_index * p.split_kv / Bc;
161
- const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
162
-
163
- // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
164
- // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
165
- const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
166
- const uint32_t iq3 = gl_WorkGroupID.z;
167
-
168
- // broadcast factors
169
- const uint32_t rk2 = p.neq2/p.nek2;
170
- const uint32_t rk3 = p.neq3/p.nek3;
171
-
172
- const uint32_t rv2 = p.neq2/p.nev2;
173
- const uint32_t rv3 = p.neq3/p.nev3;
174
-
175
- // k indices
176
- const uint32_t ik3 = iq3 / rk3;
177
- const uint32_t ik2 = iq2 / rk2;
178
-
179
- // v indices
180
- const uint32_t iv3 = iq3 / rv3;
181
- const uint32_t iv2 = iq2 / rv2;
76
+ init_indices();
182
77
 
183
78
  tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
184
79
  tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
@@ -195,17 +90,6 @@ void main() {
195
90
  tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
196
91
  tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
197
92
 
198
- // nb?1 are already divided by the type size and are in units of elements.
199
- // When using grouped query attention, Q is indexed by iq2, so the stride
200
- // should be nb02 (which is in bytes).
201
- uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
202
- uint32_t k_stride = p.nb11;
203
- uint32_t v_stride = p.nb21;
204
- // When using grouped query attention, all rows use the same mask (stride 0).
205
- // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
206
- // that prevents the compiler from folding the "&" through the select
207
- // and breaking the alignment detection.
208
- uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
209
93
  // hint to the compiler that strides are aligned for the aligned variant of the shader
210
94
  if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
211
95
  {
@@ -7,7 +7,7 @@
7
7
  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
8
8
  #endif
9
9
  #if defined(DATA_A_IQ1_M)
10
- #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
10
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
11
11
  #endif
12
12
 
13
13
  #if defined(DATA_A_BF16) && defined(COOPMAT)
@@ -215,7 +215,7 @@ static std::mutex compile_count_mutex;
215
215
  static std::condition_variable compile_count_cond;
216
216
 
217
217
  void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
218
- std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
218
+ std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
219
219
  std::string out_fname = join_paths(output_dir, name + ".spv");
220
220
  std::string in_path = join_paths(input_dir, in_fname);
221
221
 
@@ -424,6 +424,7 @@ void process_shaders() {
424
424
  // flash attention
425
425
  for (const auto& f16acc : {false, true}) {
426
426
  std::string acctype = f16acc ? "float16_t" : "float";
427
+ std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
427
428
 
428
429
  for (const auto& tname : type_names) {
429
430
  if (tname == "f32") {
@@ -440,6 +441,16 @@ void process_shaders() {
440
441
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
441
442
  merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
442
443
  }
444
+ #endif
445
+ #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
446
+ if (tname == "f16") {
447
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
448
+ merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
449
+ } else if (tname == "q4_0" || tname == "q8_0") {
450
+ std::string data_a_key = "DATA_A_" + to_uppercase(tname);
451
+ string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
452
+ merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
453
+ }
443
454
  #endif
444
455
  if (tname == "f16") {
445
456
  string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",