@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314)
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -0,0 +1,265 @@
1
#version 450

#ifdef USE_COLLECTIVES
#extension GL_KHR_shader_subgroup_shuffle : enable
#endif

#include "types.comp"

// Padding (in elements) appended to every row of the shared-memory tiles.
// TODO: make this a specialization constant.
#define SHMEM_PAD 0

// Shape notation: [dim(0), ..., dim(N)] -- dim(0) is listed first and is the
// innermost (unit-stride) dimension, i.e. stride(dim(i)) >= stride(dim(j)) if i > j.
layout(binding = 0) readonly buffer A {
    A_TYPE knl_data[];
};  // src0 - kernel: [KW, KH, Cin, Cout]

layout(binding = 1) readonly buffer B {
    B_TYPE src_data[];
};  // src1 - input: [W, H, Cin, N] -- channel_first format

layout(binding = 2) writeonly buffer D {
    D_TYPE dst_data[];
};  // dst - result: [OW, OH, Cout, N]

layout(push_constant) uniform parameter {
    // I/O channels, batch size
    uint32_t Cout;
    uint32_t Cin;
    uint32_t N;

    // Tensor spatial sizes: kernel, input, output
    uint32_t KW;
    uint32_t KH;
    uint32_t W;
    uint32_t H;
    uint32_t OW;
    uint32_t OH;

    // Parameters: stride, padding, dilation - 0=x (width), 1=y (height)
    uint32_t s0;
    uint32_t s1;
    uint32_t p0;
    uint32_t p1;
    uint32_t d0;
    uint32_t d1;

    // Strides in elements (not bytes); the unit stride of dim(0) is implicit.
    // Kernel (src0) strides of dims 1..3
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;

    // Input (src1) strides of dims 1..3
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;

    // Output (dst) strides of dims 1..3
    uint32_t nb1;
    uint32_t nb2;
    uint32_t nb3;
} p;

// Workgroup width comes from specialization constant 0.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes (per workgroup)
layout(constant_id = 1) const uint BS_K   = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes (per invocation)
layout(constant_id = 4) const uint TS_K = 8;
// Only consulted when USE_COLLECTIVES is defined: 1 selects the subgroup-shuffle index path.
layout(constant_id = 5) const uint use_collectives = 1;

uint32_t       tid     = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
74
+
75
// Ceiling division: how many block_size-wide blocks are needed to cover work_size items.
uint splitWork(uint work_size, uint block_size) {
    const uint rounded_up = work_size + block_size - 1;
    return rounded_up / block_size;
}
78
+
79
/*
 * The convolution is evaluated as a GEMM:  [K x CRS] @ [CRS x NPQ] = [K x NPQ]
 *   K   = Cout
 *   C   = Cin
 *   R,S = KH,KW
 *   P,Q = OH,OW
 */

// ---- Problem dimensions ----
uint32_t K   = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;

uint32_t n_elems_out = K * NPQ;  // not referenced elsewhere in this shader

// Number of BS_CRS-wide blocktiles along the CRS (reduction) dimension
uint32_t NB_CRS = splitWork(CRS, BS_CRS);

// ---- Shared-memory tiles; each row is padded by SHMEM_PAD elements ----
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;

const uint32_t Ash_numel = BS_K * BS_CRS;   // unpadded element counts
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;

const uint32_t Ash_len = BS_K * Ash_stride;  // padded allocation sizes
const uint32_t Bsh_len = BS_CRS * Bsh_stride;

shared float Ash[Ash_len];  // BS_K   x BS_CRS tile of the kernel
shared float Bsh[Bsh_len];  // BS_CRS x BS_NPQ tile of gathered input values

// ---- Per-invocation threadtile: TS_K x TS_NPQ outputs ----
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;

const uint32_t NT_K   = BS_K / TS_K;      // threadtiles per blocktile along K
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;  // threadtiles per blocktile along NPQ

float regA[TS_K];
float regB[TS_NPQ];
float regC[TS_K][TS_NPQ];

// Blocktile handled by this workgroup
uint32_t B_idx_K   = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;

// Threadtile handled by this invocation inside the blocktile
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;

// Cooperative load of the A (kernel) tile: ArpWg rows per pass
uint32_t       Ar    = tid / BS_CRS;
uint32_t       Ac    = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;

// Cooperative load of the B (input) tile: BrpWg rows per pass
uint32_t       Br    = tid / BS_NPQ;
uint32_t       Bc    = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
133
+
134
// Split a flattened CRS index into its (Cin, KH, KW) coordinates.
// KW varies fastest, then KH, then Cin -- the same order as the kernel layout [KW, KH, Cin, Cout].
void split_CRS_idx(uint32_t CRS_idx, out uint32_t Cin_idx, out uint32_t KH_idx, out uint32_t KW_idx) {
    Cin_idx                = CRS_idx / (p.KW * p.KH);
    uint32_t CRS_remainder = CRS_idx - Cin_idx * p.KW * p.KH;
    KH_idx                 = CRS_remainder / p.KW;
    KW_idx                 = CRS_remainder - KH_idx * p.KW;
}

/*
 * Direct 2D convolution evaluated as dst[K x NPQ] = kernel[K x CRS] @ input[CRS x NPQ].
 *
 * Workgroup (B_idx_K, B_idx_NPQ) produces one BS_K x BS_NPQ blocktile of dst. The CRS
 * dimension is walked in NB_CRS steps. At each step the workgroup stages a BS_K x BS_CRS
 * kernel tile in Ash and a BS_CRS x BS_NPQ input tile in Bsh. Every invocation then
 * accumulates its TS_K x TS_NPQ threadtile into regC, which is finally written to dst_data.
 */
void main() {
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            regC[T_ly][T_lx] = 0.0;
        }
    }
    /* Advance block in CRS dim */
    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
        // Coordinates of the CRS column (column Ac of A) this invocation loads into Ash.
        uint32_t CRS_idx_a;
        uint32_t Cin_idx_a;
        uint32_t KH_idx_a;
        uint32_t KW_idx_a;

#ifdef USE_COLLECTIVES
        // Lane i decodes CRS index B_idx_CRS * BS_CRS + i. Invocations fetch the decoded
        // values they need from the owning lane with subgroupShuffle instead of redoing
        // the divisions.
        // NOTE(review): shuffle sources go up to BS_CRS - 1, so this is only valid when the
        // subgroup has at least BS_CRS invocations -- confirm the host enforces that.
        uint32_t cached_CRS_idx;
        uint32_t cached_Cin_idx;
        uint32_t cached_KH_idx;
        uint32_t cached_KW_idx;
        if (use_collectives == 1) {
            cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
            split_CRS_idx(cached_CRS_idx, cached_Cin_idx, cached_KH_idx, cached_KW_idx);

            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
            KH_idx_a  = subgroupShuffle(cached_KH_idx, Ac);
            KW_idx_a  = subgroupShuffle(cached_KW_idx, Ac);
        } else {
            CRS_idx_a = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
            split_CRS_idx(CRS_idx_a, Cin_idx_a, KH_idx_a, KW_idx_a);
        }
#else
        CRS_idx_a = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
        split_CRS_idx(CRS_idx_a, Cin_idx_a, KH_idx_a, KW_idx_a);
#endif

        /* Load kernel to A_block: (BS_K x BS_CRS) */
        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
            uint32_t B_ly  = r_offset + Ar;
            uint32_t B_lx  = Ac;
            uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A) */
            // Clamp the index so that even a load whose value is discarded below stays in
            // range (assumes knl_data holds at least K * CRS elements).
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
            float    val     = knl_data[knl_idx];
            if (K_idx >= K || CRS_idx_a >= CRS) {
                val = 0.0;  // tile entry lies outside the kernel matrix
            }
            Ash[B_ly * Ash_stride + B_lx] = val;
        }
        /* Load input to B_block: (BS_CRS x BS_NPQ) */
        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
            uint32_t B_ly          = r_offset + Br;                /* Row index of B block */
            uint32_t B_lx          = Bc;
            uint32_t NPQ_idx       = B_idx_NPQ * BS_NPQ + B_lx;    /* Global NPQ index (column index of B) */
            uint32_t N_idx         = NPQ_idx / (p.OH * p.OW);
            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
            uint32_t OH_idx        = NPQ_remainder / p.OW;
            uint32_t OW_idx        = NPQ_remainder - OH_idx * p.OW;

            uint32_t CRS_idx_b;
            uint32_t Cin_idx_b;
            uint32_t KH_idx_b;
            uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
            if (use_collectives == 1) {
                CRS_idx_b = subgroupShuffle(cached_CRS_idx, B_ly);
                Cin_idx_b = subgroupShuffle(cached_Cin_idx, B_ly);
                KH_idx_b  = subgroupShuffle(cached_KH_idx, B_ly);
                KW_idx_b  = subgroupShuffle(cached_KW_idx, B_ly);
            } else {
                CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
                split_CRS_idx(CRS_idx_b, Cin_idx_b, KH_idx_b, KW_idx_b);
            }
#else
            CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
            split_CRS_idx(CRS_idx_b, Cin_idx_b, KH_idx_b, KW_idx_b);
#endif

            // Input coordinates are computed in unsigned arithmetic. A position inside the
            // padding (mathematically negative) wraps around to a very large value and is
            // therefore rejected by the `>= p.H` / `>= p.W` tests below.
            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
            // Clamp the index so that even a load whose value is discarded below stays in
            // range (assumes src_data holds at least Cin * N * W * H elements).
            uint32_t src_idx =
                min(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, p.Cin * p.N * p.W * p.H - 1);
            float val = src_data[src_idx];
            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx >= p.H || W_idx >= p.W) {
                val = 0.0;  // outside the matrix or inside the zero padding
            }
            Bsh[B_ly * Bsh_stride + B_lx] = val;
        }
        barrier();  // both tiles must be fully staged before any invocation reads them
        for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
            }
            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
            }
            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                    regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
                }
            }
        }
        barrier();  // all reads must finish before the next iteration overwrites the tiles
    }
    /* Save C* */
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            uint32_t K_idx   = B_idx_K * BS_K + T_y * TS_K + T_ly;
            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
            uint32_t N_idx   = NPQ_idx / (p.OH * p.OW);
            uint32_t OH_idx  = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
            uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
            uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
            if (K_idx < K && NPQ_idx < NPQ) {
                dst_data[dst_idx] = regC[T_ly][T_lx];
            }
        }
    }
}
@@ -1,22 +1,26 @@
1
1
  #version 450
2
2
 
3
- #if RTE16
4
- #extension GL_EXT_spirv_intrinsics : enable
5
- spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
6
- #endif // RTE16
7
-
3
+ #include "rte.comp"
8
4
  #include "types.comp"
9
- #include "generic_unary_head.comp"
10
5
 
11
- #if defined(DATA_A_IQ4_NL)
12
- // 16 invocations needed for init_iq4nl_shmem
13
- layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
6
+ #if defined(SET_ROWS) && QUANT_K == 1
7
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
8
+ const uint BLOCK_SIZE = 512;
14
9
  #else
15
- layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
10
+ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
11
+ const uint BLOCK_SIZE = 32;
16
12
  #endif
17
13
 
18
14
  layout (binding = 0) readonly buffer S {float data_s[];};
15
+
16
+ #if defined(SET_ROWS)
17
+ #include "generic_binary_head.comp"
18
+ layout (binding = 1) readonly buffer C {uvec2 data_i[];};
19
+ layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
20
+ #else
21
+ #include "generic_unary_head.comp"
19
22
  layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
23
+ #endif
20
24
 
21
25
  #if defined(DATA_A_Q4_0)
22
26
  void quantize(uint dst_idx, uint src_idx)
@@ -221,15 +225,56 @@ void quantize(uint dst_idx, uint src_idx)
221
225
  }
222
226
  #endif
223
227
 
228
+ #if defined(DATA_A_F32) || defined(DATA_A_F16)
229
+ void quantize(uint dst_idx, uint src_idx)
230
+ {
231
+ data_q[dst_idx] = A_TYPE(data_s[src_idx]);
232
+ }
233
+ #endif
234
+
235
+ #if defined(DATA_A_BF16)
236
+ void quantize(uint dst_idx, uint src_idx)
237
+ {
238
+ data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
239
+ }
240
+ #endif
241
+
242
+ #if defined(SET_ROWS)
243
+
224
244
  void main() {
225
245
  #ifdef NEEDS_INIT_IQ_SHMEM
226
246
  init_iq_shmem(gl_WorkGroupSize);
227
- if (gl_LocalInvocationIndex.x != 0) {
247
+ #endif
248
+
249
+ const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
250
+
251
+ if (idx >= p.ne) {
228
252
  return;
229
253
  }
254
+
255
+ uint i00, i01, i02, i03;
256
+ get_indices(idx, i00, i01, i02, i03);
257
+
258
+ uint i12 = fastmod(i03, p.ne12);
259
+ uint i11 = fastmod(i02, p.ne11);
260
+ uint i10 = i01;
261
+
262
+ uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
263
+
264
+ uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
265
+ uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
266
+
267
+ quantize(dst_idx, src0_idx);
268
+ }
269
+
270
+ #else
271
+
272
+ void main() {
273
+ #ifdef NEEDS_INIT_IQ_SHMEM
274
+ init_iq_shmem(gl_WorkGroupSize);
230
275
  #endif
231
276
 
232
- const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
277
+ const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
233
278
 
234
279
  if (idx >= p.ne) {
235
280
  return;
@@ -240,3 +285,5 @@ void main() {
240
285
 
241
286
  quantize(dst_idx, src_idx);
242
287
  }
288
+
289
+ #endif
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
10
  void main() {
11
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
12
  const uint i = gl_WorkGroupID.x * 256 + wgy;
13
- if (i >= p.M * p.K / QUANT_K) {
13
+ if (i >= p.nel / QUANT_K) {
14
14
  return;
15
15
  }
16
16
 
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
10
  void main() {
11
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
12
  const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
13
- if (i >= p.M * p.K / QUANT_K) {
13
+ if (i >= p.nel / QUANT_K) {
14
14
  return;
15
15
  }
16
16
 
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
10
  void main() {
11
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
12
  const uint ib = gl_WorkGroupID.x * 256 + wgy;
13
- if (ib >= p.M * p.K / QUANT_K) {
13
+ if (ib >= p.nel / QUANT_K) {
14
14
  return;
15
15
  }
16
16
 
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
10
  void main() {
11
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
12
  const uint ib = gl_WorkGroupID.x * 256 + wgy;
13
- if (ib >= p.M * p.K / QUANT_K) {
13
+ if (ib >= p.nel / QUANT_K) {
14
14
  return;
15
15
  }
16
16
 
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
10
10
  void main() {
11
11
  [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
12
12
  const uint i = gl_WorkGroupID.x * 256 + wgy;
13
- if (i >= p.M * p.K / QUANT_K) {
13
+ if (i >= p.nel / QUANT_K) {
14
14
  return;
15
15
  }
16
16
  const uint tid = gl_LocalInvocationID.x;
@@ -11,7 +11,8 @@
11
11
  #include "types.comp"
12
12
  #include "flash_attn_base.comp"
13
13
 
14
- const uint32_t D_per_thread = D / D_split;
14
+ const uint32_t HSK_per_thread = HSK / D_split;
15
+ const uint32_t HSV_per_thread = HSV / D_split;
15
16
 
16
17
  const uint32_t cols_per_iter = WorkGroupSize / D_split;
17
18
  const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
29
30
  // Rows index by Q's dimension 2, and the first N rows are valid.
30
31
  D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
31
32
  {
32
- uint32_t offset = (iq2 + r) * D + c;
33
+ uint32_t offset = (iq2 + r) * HSV + c;
33
34
  data_o[o_offset + offset] = D_TYPE(elem);
34
35
  return elem;
35
36
  }
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
38
39
  shared vec4 tmpshv4[WorkGroupSize];
39
40
 
40
41
  shared float masksh[Bc][Br];
41
- shared vec4 Qf[Br][D / 4];
42
+ shared vec4 Qf[Br][HSK / 4];
42
43
 
43
44
  void main() {
44
45
  #ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {
53
54
 
54
55
  uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
55
56
 
56
- [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
57
- uint32_t d = (idx + tid) % (D / 4);
58
- uint32_t r = (idx + tid) / (D / 4);
59
- if (r < Br && d < D / 4 &&
57
+ [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
58
+ uint32_t d = (idx + tid) % (HSK / 4);
59
+ uint32_t r = (idx + tid) / (HSK / 4);
60
+ if (r < Br && d < HSK / 4 &&
60
61
  i * Br + r < N) {
61
62
  Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
62
63
  }
63
64
  }
64
65
  barrier();
65
66
 
66
- vec4 Of[Br][D_per_thread / 4];
67
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
67
+ vec4 Of[Br][HSV_per_thread / 4];
68
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
68
69
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
69
70
  Of[r][d] = vec4(0.0);
70
71
  }
@@ -99,6 +100,10 @@ void main() {
99
100
  uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
100
101
  uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
101
102
  #endif
103
+ uint32_t m_offset = 0;
104
+ if (p.nem2 != 1 || p.nem3 != 1) {
105
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
106
+ }
102
107
 
103
108
  [[dont_unroll]]
104
109
  for (uint32_t j = start_j; j < end_j; ++j) {
@@ -112,7 +117,7 @@ void main() {
112
117
 
113
118
 
114
119
  [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
115
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
120
+ [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
116
121
  #if BLOCK_SIZE > 1
117
122
  uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
118
123
  uint ib = coord / BLOCK_SIZE;
@@ -144,13 +149,13 @@ void main() {
144
149
  }
145
150
  }
146
151
 
147
- if (p.mask != 0) {
152
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
148
153
 
149
154
  [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
150
155
  uint32_t c = (idx + tid) % Bc;
151
156
  uint32_t r = (idx + tid) / Bc;
152
157
  if (idx + tid < Bc * Br) {
153
- masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
158
+ masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
154
159
  }
155
160
  }
156
161
  barrier();
@@ -191,14 +196,14 @@ void main() {
191
196
  Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
192
197
  }
193
198
 
194
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
199
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
195
200
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
196
201
  Of[r][d] = eMf[r] * Of[r][d];
197
202
  }
198
203
  }
199
204
 
200
205
  [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
201
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
206
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
202
207
  #if BLOCK_SIZE > 1
203
208
  uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
204
209
  uint ib = coord / BLOCK_SIZE;
@@ -255,7 +260,7 @@ void main() {
255
260
  Lf[r] = tmpsh[d_tid];
256
261
  barrier();
257
262
 
258
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
263
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
259
264
 
260
265
  Of[r][d] = eMf * Of[r][d];
261
266
  tmpshv4[tid] = Of[r][d];
@@ -277,11 +282,11 @@ void main() {
277
282
  // If there is split_k, then the split_k resolve shader does the final
278
283
  // division by L. Store the intermediate O value and per-row m and L values.
279
284
  if (p.k_num > 1) {
280
- uint32_t o_offset = D * p.ne1 * split_k_index;
285
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
281
286
 
282
287
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
283
288
  if (r < N) {
284
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
289
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
285
290
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
286
291
  perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
287
292
  }
@@ -289,7 +294,7 @@ void main() {
289
294
  }
290
295
  }
291
296
 
292
- o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
297
+ o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
293
298
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
294
299
  if (r < N) {
295
300
  perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -305,18 +310,18 @@ void main() {
305
310
  Lfrcp[r] = 1.0 / Lf[r];
306
311
  }
307
312
 
308
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
313
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
309
314
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
310
315
  Of[r][d] *= Lfrcp[r];
311
316
  }
312
317
  }
313
318
 
314
- uint32_t o_offset = iq3*p.ne2*p.ne1;
319
+ uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
315
320
 
316
321
  if (p.gqa_ratio > 1) {
317
322
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
318
323
  if (r < N) {
319
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
324
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
320
325
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
321
326
  perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
322
327
  }
@@ -326,9 +331,9 @@ void main() {
326
331
  } else {
327
332
  [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
328
333
  if (i * Br + r < N) {
329
- [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
334
+ [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
330
335
  [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
331
- data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
336
+ data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
332
337
  }
333
338
  }
334
339
  }
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
4
4
  layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
5
5
  layout (constant_id = 1) const uint32_t Br = 1;
6
6
  layout (constant_id = 2) const uint32_t Bc = 32;
7
- layout (constant_id = 3) const uint32_t D = 32;
8
- layout (constant_id = 4) const uint32_t Clamp = 0;
9
- layout (constant_id = 5) const uint32_t D_split = 16;
10
-
7
+ layout (constant_id = 3) const uint32_t HSK = 32;
8
+ layout (constant_id = 4) const uint32_t HSV = 32;
9
+ layout (constant_id = 5) const uint32_t Clamp = 0;
10
+ layout (constant_id = 6) const uint32_t D_split = 16;
11
11
 
12
12
  layout (push_constant) uniform parameter {
13
13
  uint32_t N;
@@ -24,6 +24,8 @@ layout (push_constant) uniform parameter {
24
24
  uint32_t nev2;
25
25
  uint32_t nev3;
26
26
  uint32_t nem1;
27
+ uint32_t nem2;
28
+ uint32_t nem3;
27
29
 
28
30
  uint32_t nb01;
29
31
  uint32_t nb02;
@@ -34,14 +36,12 @@ layout (push_constant) uniform parameter {
34
36
  uint32_t nb21;
35
37
  uint32_t nb22;
36
38
  uint32_t nb23;
37
- uint32_t nb31;
38
39
 
39
40
  float scale;
40
41
  float max_bias;
41
42
  float logit_softcap;
42
43
 
43
- uint32_t mask;
44
- uint32_t n_head_log2;
44
+ uint32_t mask_n_head_log2;
45
45
  float m0;
46
46
  float m1;
47
47
 
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
50
50
  uint32_t k_num;
51
51
  } p;
52
52
 
53
+ #define MASK_ENABLE_BIT (1<<16)
54
+ #define N_LOG2_MASK 0xFFFF
55
+
53
56
  layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
54
57
 
55
58
  #if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
100
103
  {
101
104
  const uint32_t h = iq2 + (r % p.gqa_ratio);
102
105
 
103
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
104
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
106
+ uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
107
+
108
+ const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
109
+ const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
105
110
 
106
111
  return ACC_TYPE(pow(base, ACC_TYPE(exph)));
107
112
  }