@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -19,6 +19,10 @@ if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
19
19
  add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
20
20
  message(STATUS "Enabling bfloat16 glslc support")
21
21
  endif()
22
+ if (GGML_VULKAN_SHADER_DEBUG_INFO)
23
+ add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
24
+ message(STATUS "Enabling shader debug info")
25
+ endif()
22
26
 
23
27
  set(TARGET vulkan-shaders-gen)
24
28
  add_executable(${TARGET} vulkan-shaders-gen.cpp)
@@ -6,17 +6,25 @@ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bi
6
6
  #endif // RTE16
7
7
 
8
8
  #include "types.comp"
9
- #include "generic_unary_head.comp"
10
9
 
11
- #if defined(DATA_A_IQ4_NL)
12
- // 16 invocations needed for init_iq4nl_shmem
13
- layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
10
+ #if defined(SET_ROWS) && QUANT_K == 1
11
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
12
+ const uint BLOCK_SIZE = 512;
14
13
  #else
15
- layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
14
+ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
15
+ const uint BLOCK_SIZE = 32;
16
16
  #endif
17
17
 
18
18
  layout (binding = 0) readonly buffer S {float data_s[];};
19
+
20
+ #if defined(SET_ROWS)
21
+ #include "generic_binary_head.comp"
22
+ layout (binding = 1) readonly buffer C {uvec2 data_i[];};
23
+ layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
24
+ #else
25
+ #include "generic_unary_head.comp"
19
26
  layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
27
+ #endif
20
28
 
21
29
  #if defined(DATA_A_Q4_0)
22
30
  void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +229,56 @@ void quantize(uint dst_idx, uint src_idx)
221
229
  }
222
230
  #endif
223
231
 
232
+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
233
+ void quantize(uint dst_idx, uint src_idx)
234
+ {
235
+ data_q[dst_idx] = A_TYPE(data_s[src_idx]);
236
+ }
237
+ #endif
238
+
239
+ #if defined(DATA_A_BF16)
240
+ void quantize(uint dst_idx, uint src_idx)
241
+ {
242
+ data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
243
+ }
244
+ #endif
245
+
246
+ #if defined(SET_ROWS)
247
+
224
248
  void main() {
225
249
  #ifdef NEEDS_INIT_IQ_SHMEM
226
250
  init_iq_shmem(gl_WorkGroupSize);
227
- if (gl_LocalInvocationIndex.x != 0) {
251
+ #endif
252
+
253
+ const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
254
+
255
+ if (idx >= p.ne) {
228
256
  return;
229
257
  }
258
+
259
+ uint i00, i01, i02, i03;
260
+ get_indices(idx, i00, i01, i02, i03);
261
+
262
+ uint i12 = fastmod(i03, p.ne12);
263
+ uint i11 = fastmod(i02, p.ne11);
264
+ uint i10 = i01;
265
+
266
+ uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
267
+
268
+ uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
269
+ uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
270
+
271
+ quantize(dst_idx, src0_idx);
272
+ }
273
+
274
+ #else
275
+
276
+ void main() {
277
+ #ifdef NEEDS_INIT_IQ_SHMEM
278
+ init_iq_shmem(gl_WorkGroupSize);
230
279
  #endif
231
280
 
232
- const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
281
+ const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
233
282
 
234
283
  if (idx >= p.ne) {
235
284
  return;
@@ -240,3 +289,5 @@ void main() {
240
289
 
241
290
  quantize(dst_idx, src_idx);
242
291
  }
292
+
293
+ #endif
@@ -11,7 +11,8 @@
11
11
  #include "types.comp"
12
12
  #include "flash_attn_base.comp"
13
13
 
14
- const uint32_t D_per_thread = D / D_split;
14
+ const uint32_t HSK_per_thread = HSK / D_split;
15
+ const uint32_t HSV_per_thread = HSV / D_split;
15
16
 
16
17
  const uint32_t cols_per_iter = WorkGroupSize / D_split;
17
18
  const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
29
30
  // Rows index by Q's dimension 2, and the first N rows are valid.
30
31
  D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
31
32
  {
32
- uint32_t offset = (iq2 + r) * D + c;
33
+ uint32_t offset = (iq2 + r) * HSV + c;
33
34
  data_o[o_offset + offset] = D_TYPE(elem);
34
35
  return elem;
35
36
  }
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
38
39
  shared vec4 tmpshv4[WorkGroupSize];
39
40
 
40
41
  shared float masksh[Bc][Br];
41
- shared vec4 Qf[Br][D / 4];
42
+ shared vec4 Qf[Br][HSK / 4];
42
43
 
43
44
  void main() {
44
45
  #ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {
53
54
 
54
55
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
55
56
 
56
- [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
57
- uint32_t d = (idx + tid) % (D / 4);
58
- uint32_t r = (idx + tid) / (D / 4);
59
- if (r < Br && d < D / 4 &&
57
+ [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
58
+ uint32_t d = (idx + tid) % (HSK / 4);
59
+ uint32_t r = (idx + tid) / (HSK / 4);
60
+ if (r < Br && d < HSK / 4 &&
60
61
  i * Br + r < N) {
61
62
  Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
62
63
  }
63
64
  }
64
65
  barrier();
65
66
 
66
- vec4 Of[Br][D_per_thread / 4];
67
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
67
+ vec4 Of[Br][HSV_per_thread / 4];
68
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
68
69
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
69
70
  Of[r][d] = vec4(0.0);
70
71
  }
@@ -99,6 +100,10 @@ void main() {
99
100
  uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
100
101
  uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
101
102
  #endif
103
+ uint32_t m_offset = 0;
104
+ if (p.nem2 != 1 || p.nem3 != 1) {
105
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
106
+ }
102
107
 
103
108
  [[dont_unroll]]
104
109
  for (uint32_t j = start_j; j < end_j; ++j) {
@@ -112,7 +117,7 @@ void main() {
112
117
 
113
118
 
114
119
  [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
115
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
120
+ [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
116
121
  #if BLOCK_SIZE > 1
117
122
  uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
118
123
  uint ib = coord / BLOCK_SIZE;
@@ -144,13 +149,13 @@ void main() {
144
149
  }
145
150
  }
146
151
 
147
- if (p.mask != 0) {
152
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
148
153
 
149
154
  [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
150
155
  uint32_t c = (idx + tid) % Bc;
151
156
  uint32_t r = (idx + tid) / Bc;
152
157
  if (idx + tid < Bc * Br) {
153
- masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
158
+ masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
154
159
  }
155
160
  }
156
161
  barrier();
@@ -191,14 +196,14 @@ void main() {
191
196
  Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
192
197
  }
193
198
 
194
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
199
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
195
200
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
196
201
  Of[r][d] = eMf[r] * Of[r][d];
197
202
  }
198
203
  }
199
204
 
200
205
  [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
201
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
206
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
202
207
  #if BLOCK_SIZE > 1
203
208
  uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
204
209
  uint ib = coord / BLOCK_SIZE;
@@ -255,7 +260,7 @@ void main() {
255
260
  Lf[r] = tmpsh[d_tid];
256
261
  barrier();
257
262
 
258
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
263
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
259
264
 
260
265
  Of[r][d] = eMf * Of[r][d];
261
266
  tmpshv4[tid] = Of[r][d];
@@ -277,11 +282,11 @@ void main() {
277
282
  // If there is split_k, then the split_k resolve shader does the final
278
283
  // division by L. Store the intermediate O value and per-row m and L values.
279
284
  if (p.k_num > 1) {
280
- uint32_t o_offset = D * p.ne1 * split_k_index;
285
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
281
286
 
282
287
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
283
288
  if (r < N) {
284
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
289
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
285
290
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
286
291
  perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
287
292
  }
@@ -289,7 +294,7 @@ void main() {
289
294
  }
290
295
  }
291
296
 
292
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
297
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
293
298
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
294
299
  if (r < N) {
295
300
  perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -305,18 +310,18 @@ void main() {
305
310
  Lfrcp[r] = 1.0 / Lf[r];
306
311
  }
307
312
 
308
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
313
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
309
314
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
310
315
  Of[r][d] *= Lfrcp[r];
311
316
  }
312
317
  }
313
318
 
314
- uint32_t o_offset = iq3*p.ne2*p.ne1;
319
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
315
320
 
316
321
  if (p.gqa_ratio > 1) {
317
322
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
318
323
  if (r < N) {
319
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
324
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
320
325
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
321
326
  perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
322
327
  }
@@ -326,9 +331,9 @@ void main() {
326
331
  } else {
327
332
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
328
333
  if (i * Br + r < N) {
329
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
334
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
330
335
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
331
- data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
336
+ data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
332
337
  }
333
338
  }
334
339
  }
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
4
4
  layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
5
5
  layout (constant_id = 1) const uint32_t Br = 1;
6
6
  layout (constant_id = 2) const uint32_t Bc = 32;
7
- layout (constant_id = 3) const uint32_t D = 32;
8
- layout (constant_id = 4) const uint32_t Clamp = 0;
9
- layout (constant_id = 5) const uint32_t D_split = 16;
10
-
7
+ layout (constant_id = 3) const uint32_t HSK = 32;
8
+ layout (constant_id = 4) const uint32_t HSV = 32;
9
+ layout (constant_id = 5) const uint32_t Clamp = 0;
10
+ layout (constant_id = 6) const uint32_t D_split = 16;
11
11
 
12
12
  layout (push_constant) uniform parameter {
13
13
  uint32_t N;
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
24
24
  uint32_t nev2;
25
25
  uint32_t nev3;
26
26
  uint32_t nem1;
27
+ uint32_t nem2;
28
+ uint32_t nem3;
27
29
 
28
30
  uint32_t nb01;
29
31
  uint32_t nb02;
@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
34
36
  uint32_t nb21;
35
37
  uint32_t nb22;
36
38
  uint32_t nb23;
37
- uint32_t nb31;
38
39
 
39
40
  float scale;
40
41
  float max_bias;
41
42
  float logit_softcap;
42
43
 
43
- uint32_t mask;
44
- uint32_t n_head_log2;
44
+ uint32_t mask_n_head_log2;
45
45
  float m0;
46
46
  float m1;
47
47
 
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
50
50
  uint32_t k_num;
51
51
  } p;
52
52
 
53
+ #define MASK_ENABLE_BIT (1<<16)
54
+ #define N_LOG2_MASK 0xFFFF
55
+
53
56
  layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
54
57
 
55
58
  #if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
100
103
  {
101
104
  const uint32_t h = iq2 + (r % p.gqa_ratio);
102
105
 
103
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
104
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
106
+ uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
107
+
108
+ const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
109
+ const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
105
110
 
106
111
  return ACC_TYPE(pow(base, ACC_TYPE(exph)));
107
112
  }
@@ -13,7 +13,9 @@
13
13
  #include "types.comp"
14
14
  #include "flash_attn_base.comp"
15
15
 
16
- const uint32_t D_per_thread = D / D_split;
16
+ const uint32_t HSK_per_thread = HSK / D_split;
17
+ const uint32_t HSV_per_thread = HSV / D_split;
18
+
17
19
  const uint32_t row_split = 4;
18
20
  const uint32_t rows_per_thread = Br / row_split;
19
21
  const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
32
34
  // Rows index by Q's dimension 2, and the first N rows are valid.
33
35
  D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
34
36
  {
35
- uint32_t offset = (iq2 + r) * D + c;
37
+ uint32_t offset = (iq2 + r) * HSV + c;
36
38
  data_o[o_offset + offset] = D_TYPE(elem);
37
39
  return elem;
38
40
  }
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
44
46
  shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
45
47
  shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
46
48
 
47
- const uint32_t qstride = D / 4 + 2; // in units of f16vec4
49
+ const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
48
50
  shared f16vec4 Qf[Br * qstride];
49
51
 
50
- // Avoid padding for D==256 to make it fit in 48KB shmem.
51
- const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
52
+ // Avoid padding for hsk==256 to make it fit in 48KB shmem.
53
+ const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
52
54
  shared ACC_TYPE sfsh[Bc * sfshstride];
53
55
 
54
- const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
56
+ const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
55
57
  shared f16vec4 ksh[Bc * kshstride];
56
58
 
57
59
  shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {
74
76
 
75
77
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
76
78
 
77
- [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
78
- uint32_t d = (idx + tid) % (D / 4);
79
- uint32_t r = (idx + tid) / (D / 4);
80
- if (r < Br && d < D / 4 &&
79
+ [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
80
+ uint32_t d = (idx + tid) % (HSK / 4);
81
+ uint32_t r = (idx + tid) / (HSK / 4);
82
+ if (r < Br && d < HSK / 4 &&
81
83
  i * Br + r < N) {
82
84
  Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
83
85
  }
84
86
  }
85
87
  barrier();
86
88
 
87
- ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
88
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
89
+ ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
90
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
89
91
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
90
92
  Of[r][d] = ACC_TYPEV4(0.0);
91
93
  }
@@ -123,14 +125,18 @@ void main() {
123
125
  uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
124
126
  uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
125
127
  #endif
128
+ uint32_t m_offset = 0;
129
+ if (p.nem2 != 1 || p.nem3 != 1) {
130
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
131
+ }
126
132
 
127
133
  [[dont_unroll]]
128
134
  for (uint32_t j = start_j; j < end_j; ++j) {
129
135
 
130
- [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
131
- uint32_t d = (idx + tid) % (D / 4);
132
- uint32_t c = (idx + tid) / (D / 4);
133
- if (c < Bc && d < D / 4) {
136
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
137
+ uint32_t d = (idx + tid) % (HSK / 4);
138
+ uint32_t c = (idx + tid) / (HSK / 4);
139
+ if (c < Bc && d < HSK / 4) {
134
140
  #if BLOCK_SIZE > 1
135
141
  uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
136
142
  uint ib = coord / BLOCK_SIZE;
@@ -145,14 +151,14 @@ void main() {
145
151
  }
146
152
  barrier();
147
153
 
148
- // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
149
- // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
154
+ // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
155
+ // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
150
156
  // This is written transposed in order to allow for N being 8 if implementations need it
151
157
  coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
152
158
  coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
153
159
  coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
154
160
 
155
- for (uint32_t d = 0; d < D / 16; ++d) {
161
+ for (uint32_t d = 0; d < HSK / 16; ++d) {
156
162
  coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
157
163
 
158
164
  uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -176,12 +182,12 @@ void main() {
176
182
  barrier();
177
183
  }
178
184
 
179
- if (p.mask != 0) {
185
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
180
186
  [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
181
187
  uint32_t c = (idx + tid) % Bc;
182
188
  uint32_t r = (idx + tid) / Bc;
183
189
  if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
184
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
190
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
185
191
  }
186
192
  }
187
193
  barrier();
@@ -202,7 +208,7 @@ void main() {
202
208
  eMf[r] = exp(Moldf - Mf[r]);
203
209
  }
204
210
 
205
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
211
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
206
212
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
207
213
  Of[r][d] = float16_t(eMf[r]) * Of[r][d];
208
214
  }
@@ -217,7 +223,7 @@ void main() {
217
223
  Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
218
224
  Lf[r] += Pf[r];
219
225
  }
220
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
226
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
221
227
  #if BLOCK_SIZE > 1
222
228
  uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
223
229
  uint ib = coord / BLOCK_SIZE;
@@ -280,7 +286,7 @@ void main() {
280
286
  }
281
287
 
282
288
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
283
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
289
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
284
290
 
285
291
  Of[r][d] = float16_t(eMf[r]) * Of[r][d];
286
292
  tmpshv4[tid] = Of[r][d];
@@ -300,11 +306,11 @@ void main() {
300
306
  // If there is split_k, then the split_k resolve shader does the final
301
307
  // division by L. Store the intermediate O value and per-row m and L values.
302
308
  if (p.k_num > 1) {
303
- uint32_t o_offset = D * p.ne1 * split_k_index;
309
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
304
310
 
305
311
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
306
312
  if (tile_row(r) < N) {
307
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
313
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
308
314
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
309
315
  perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
310
316
  }
@@ -312,7 +318,7 @@ void main() {
312
318
  }
313
319
  }
314
320
 
315
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
321
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
316
322
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
317
323
  if (tile_row(r) < N) {
318
324
  perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -328,18 +334,18 @@ void main() {
328
334
  Lfrcp[r] = 1.0 / Lf[r];
329
335
  }
330
336
 
331
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
337
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
332
338
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
333
339
  Of[r][d] *= float16_t(Lfrcp[r]);
334
340
  }
335
341
  }
336
342
 
337
- uint32_t o_offset = iq3*p.ne2*p.ne1;
343
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
338
344
 
339
345
  if (p.gqa_ratio > 1) {
340
346
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
341
347
  if (tile_row(r) < N) {
342
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
348
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
343
349
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
344
350
  perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
345
351
  }
@@ -349,9 +355,9 @@ void main() {
349
355
  } else {
350
356
  [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
351
357
  if (i * Br + tile_row(r) < N) {
352
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
358
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
353
359
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
354
- data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
360
+ data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
355
361
  }
356
362
  }
357
363
  }