@novastera-oss/llamarn 0.2.9 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (314) hide show
  1. package/android/build.gradle +2 -1
  2. package/android/proguard-rules.pro +12 -0
  3. package/android/src/main/cpp/include/llama.h +15 -47
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +0 -1
  22. package/cpp/llama.cpp/CMakePresets.json +11 -0
  23. package/cpp/llama.cpp/CODEOWNERS +1 -0
  24. package/cpp/llama.cpp/README.md +8 -8
  25. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  26. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  27. package/cpp/llama.cpp/common/arg.cpp +62 -1
  28. package/cpp/llama.cpp/common/chat.cpp +37 -20
  29. package/cpp/llama.cpp/common/chat.h +2 -0
  30. package/cpp/llama.cpp/common/common.cpp +22 -6
  31. package/cpp/llama.cpp/common/common.h +22 -4
  32. package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
  33. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
  34. package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
  35. package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
  36. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  37. package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
  38. package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
  39. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  40. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
  41. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
  42. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
  43. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  44. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
  45. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
  46. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
  61. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
  62. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
  63. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  64. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
  65. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
  66. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
  95. package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  97. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
  98. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
  99. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
  100. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
  101. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
  102. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
  103. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  104. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  105. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  131. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
  132. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
  133. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
  134. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  135. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  136. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  137. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  138. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  139. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  140. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  141. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  142. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  143. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  144. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
  173. package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
  177. package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
  178. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
  179. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
  180. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
  181. package/cpp/llama.cpp/include/llama.h +15 -47
  182. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
  183. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
  184. package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
  185. package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
  186. package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
  187. package/cpp/llama.cpp/src/llama-arch.h +23 -1
  188. package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
  189. package/cpp/llama.cpp/src/llama-batch.h +31 -18
  190. package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
  191. package/cpp/llama.cpp/src/llama-chat.h +3 -0
  192. package/cpp/llama.cpp/src/llama-context.cpp +180 -106
  193. package/cpp/llama.cpp/src/llama-context.h +26 -16
  194. package/cpp/llama.cpp/src/llama-cparams.h +3 -2
  195. package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
  196. package/cpp/llama.cpp/src/llama-graph.h +184 -122
  197. package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
  198. package/cpp/llama.cpp/src/llama-hparams.h +13 -2
  199. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
  200. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
  201. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
  202. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
  203. package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
  204. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
  205. package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
  206. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
  207. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  208. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  209. package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
  210. package/cpp/llama.cpp/src/llama-model.h +21 -4
  211. package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
  212. package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
  213. package/cpp/llama.cpp/src/llama-vocab.h +43 -0
  214. package/cpp/llama.cpp/src/unicode.cpp +207 -0
  215. package/cpp/llama.cpp/src/unicode.h +2 -0
  216. package/ios/include/chat.h +2 -0
  217. package/ios/include/common.h +22 -4
  218. package/ios/include/llama.h +15 -47
  219. package/ios/libs/llama.xcframework/Info.plist +13 -13
  220. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  221. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  223. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
  224. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
  225. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  231. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  232. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  248. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
  250. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  254. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  255. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  261. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  262. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
  263. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  264. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
  265. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  267. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  268. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
  269. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
  270. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  274. package/package.json +4 -4
  275. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  276. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  277. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  278. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  279. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  280. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -18,6 +18,7 @@
18
18
  #extension GL_KHR_cooperative_matrix : enable
19
19
  #extension GL_KHR_memory_scope_semantics : enable
20
20
  #extension GL_KHR_shader_subgroup_basic : enable
21
+ #extension GL_KHR_shader_subgroup_ballot : enable
21
22
  #endif
22
23
 
23
24
  #ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
104
105
 
105
106
  #ifdef MUL_MAT_ID
106
107
  shared u16vec2 row_ids[4096];
108
+ uint _ne1;
109
+ #ifdef COOPMAT
110
+ shared uint _ne1_sh;
111
+ #endif
107
112
  #endif // MUL_MAT_ID
108
113
 
109
114
  #define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
172
177
  const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
173
178
 
174
179
  #ifdef MUL_MAT_ID
175
- uint _ne1 = 0;
180
+ #ifdef COOPMAT
181
+ // Spread the search across all elements in the first subgroup
182
+ if (gl_SubgroupID == 0) {
183
+ _ne1 = 0;
184
+ uint num_elements = p.nei1 * p.nei0;
185
+
186
+ uint ids[16];
187
+ uint iter = 0;
188
+
189
+ for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
190
+ // prefetch up to 16 elements
191
+ if (iter == 0) {
192
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
193
+ uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
194
+ bool in_range = i < num_elements;
195
+ uint ii1 = i / p.nei0;
196
+ uint ii0 = i % p.nei0;
197
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
198
+ }
199
+ }
200
+ uint i = j + gl_SubgroupInvocationID;
201
+ bool in_range = i < num_elements;
202
+ uint ii1 = i / p.nei0;
203
+ uint ii0 = i % p.nei0;
204
+ uint id = ids[iter++];
205
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
206
+ uint idx = subgroupBallotExclusiveBitCount(ballot);
207
+ if (in_range && id == expert_idx) {
208
+ row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
209
+ }
210
+ _ne1 += subgroupBallotBitCount(ballot);
211
+ iter &= 15;
212
+ }
213
+ _ne1_sh = _ne1;
214
+ }
215
+
216
+ barrier();
217
+
218
+ _ne1 = _ne1_sh;
219
+ #else
220
+ _ne1 = 0;
176
221
  for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
177
222
  for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
178
223
  if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
183
228
  }
184
229
 
185
230
  barrier();
231
+ #endif
186
232
 
187
233
  // Workgroup has no work
188
234
  if (ic * BN >= _ne1) return;
@@ -500,10 +546,9 @@ void main() {
500
546
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
501
547
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
502
548
 
503
- const uint ib = idx / 128; // 2 values per idx
504
- const uint ib32 = (idx % 128) / 16; // 0..7
505
- const uint ib8 = (idx % 128) / 4;
506
- const int i8 = 2 * int(idx % 4);
549
+ const uint ib = idx / 32; // 8 values per idx
550
+ const uint ib32 = (idx % 32) / 4; // 0..7
551
+ const uint ib8 = idx % 32;
507
552
 
508
553
  const float d = float(data_a[ib].d);
509
554
  const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +557,16 @@ void main() {
512
557
  const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
513
558
  const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
514
559
 
515
- const ivec2 gvec = ivec2(
516
- bitfieldExtract(grid, 2 * (i8), 2),
517
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
518
- );
519
- const vec2 v = dl * (vec2(gvec) + delta);
520
-
521
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
522
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
560
+ [[unroll]] for (int k = 0; k < 8; ++k) {
561
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
562
+ }
523
563
  #elif defined(DATA_A_IQ1_M)
524
564
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
525
565
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
526
566
 
527
- const uint ib = idx / 128; // 2 values per idx
528
- const uint ib8 = (idx % 128) / 4;
567
+ const uint ib = idx / 32; // 8 values per idx
568
+ const uint ib8 = idx % 32;
529
569
  const uint ib16 = ib8 / 2;
530
- const int i8 = 2 * int(idx % 4);
531
570
 
532
571
  const uint16_t[4] scales = data_a[ib].scales;
533
572
  const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +577,17 @@ void main() {
538
577
  const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
539
578
  const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
540
579
  const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
541
- const ivec2 gvec = ivec2(
542
- bitfieldExtract(grid, 2 * (i8), 2),
543
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
544
- );
545
- const vec2 v = dl * (vec2(gvec) + delta);
546
580
 
547
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
548
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
581
+ [[unroll]] for (int k = 0; k < 8; ++k) {
582
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
583
+ }
549
584
  #elif defined(DATA_A_IQ2_XXS)
550
585
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
551
586
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
552
587
 
553
- const uint ib = idx / 128; // 2 values per idx
554
- const uint ib32 = (idx % 128) / 16; // 0..7
555
- const uint ib8 = (idx / 4) % 4;
588
+ const uint ib = idx / 32; // 8 values per idx
589
+ const uint ib32 = (idx % 32) / 4; // 0..7
590
+ const uint ib8 = idx % 4;
556
591
 
557
592
  const float d = float(data_a[ib].d);
558
593
  const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +597,81 @@ void main() {
562
597
  data_a[ib].qs[8*ib32 + 6],
563
598
  data_a[ib].qs[8*ib32 + 7]
564
599
  ));
565
- const float db = d * 0.25 * (0.5 + (signs >> 28));
600
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
566
601
  const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
567
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
568
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
569
- const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
570
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
571
-
572
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
573
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
602
+ const uint sign = sign7 | (bitCount(sign7) << 7);
603
+ const uvec2 grid = iq2xxs_grid[qs];
604
+ const vec4 grid0 = vec4(unpack8(grid.x));
605
+ const vec4 grid1 = vec4(unpack8(grid.y));
606
+
607
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
608
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
609
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
610
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
611
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
612
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
613
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
614
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
574
615
  #elif defined(DATA_A_IQ2_XS)
575
616
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
576
617
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
577
618
 
578
- const uint ib = idx / 128; // 2 values per idx
579
- const uint ib32 = (idx % 128) / 16; // 0..7
580
- const uint ib8 = (idx / 4) % 4; // 0..3
619
+ const uint ib = idx / 32; // 8 values per idx
620
+ const uint ib32 = (idx % 32) / 4; // 0..7
621
+ const uint ib8 = idx % 4; // 0..3
581
622
 
582
623
  const float d = float(data_a[ib].d);
583
624
  const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
584
- const float db = d * 0.25 * (0.5 + scale);
625
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
585
626
  const uint qs = data_a[ib].qs[4 * ib32 + ib8];
586
627
  const uint sign7 = qs >> 9;
587
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
588
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
589
- const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
590
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
591
-
592
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
593
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
628
+ const uint sign = sign7 | (bitCount(sign7) << 7);
629
+ const uvec2 grid = iq2xs_grid[qs & 511];
630
+ const vec4 grid0 = vec4(unpack8(grid.x));
631
+ const vec4 grid1 = vec4(unpack8(grid.y));
632
+
633
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
634
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
635
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
636
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
637
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
638
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
639
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
640
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
594
641
  #elif defined(DATA_A_IQ2_S)
595
642
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
596
643
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
597
644
 
598
- const uint ib = idx / 128; // 2 values per idx
599
- const uint ib8 = (idx % 128) / 4; // 0..31
600
- const uint ib32 = ib8 / 4; // 0..7
645
+ const uint ib = idx / 32; // 8 values per idx
646
+ const uint ib8 = idx % 32; // 0..31
647
+ const uint ib32 = ib8 / 4; // 0..7
601
648
 
602
649
  const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
603
650
  const uint qs = data_a[ib].qs[ib8];
604
651
  const uint qh = data_a[ib].qh[ib32];
605
652
  const uint qhshift = 2 * (ib8 % 4);
606
- const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
653
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
607
654
 
608
655
  const float d = float(data_a[ib].d);
609
- const float db = d * 0.25 * (0.5 + scale);
610
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
611
- const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
612
- const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
613
-
614
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
615
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
656
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
657
+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
658
+ const vec4 grid0 = vec4(unpack8(grid.x));
659
+ const vec4 grid1 = vec4(unpack8(grid.y));
660
+
661
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
662
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
663
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
664
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
665
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
666
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
667
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
668
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
616
669
  #elif defined(DATA_A_IQ3_XXS)
617
670
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
618
671
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
619
672
 
620
- const uint ib = idx / 128; // 2 values per idx
621
- const uint iqs = (idx % 128) / 2; // 0..63
673
+ const uint ib = idx / 64; // 4 values per idx
674
+ const uint iqs = idx % 64; // 0..63
622
675
  const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
623
676
 
624
677
  const float d = float(data_a[ib].d);
@@ -631,33 +684,36 @@ void main() {
631
684
  ));
632
685
  const float db = d * 0.5 * (0.5 + (signs >> 28));
633
686
  const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
634
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
635
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
636
- const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
637
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
638
-
639
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
640
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
687
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
688
+ const uint grid = iq3xxs_grid[qs];
689
+ const vec4 v = db * vec4(unpack8(grid));
690
+
691
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
692
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
693
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
694
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
641
695
  #elif defined(DATA_A_IQ3_S)
642
696
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
643
697
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
644
698
 
645
- const uint ib = idx / 128; // 2 values per idx
646
- const uint iqs = (idx % 128) / 2; // 0..63
699
+ const uint ib = idx / 64; // 4 values per idx
700
+ const uint iqs = idx % 64; // 0..63
647
701
  const uint iqh = iqs / 8;
648
702
 
649
703
  const float d = float(data_a[ib].d);
650
704
  const uint qs = data_a[ib].qs[iqs];
651
705
  const uint qh = data_a[ib].qh[iqh];
652
- const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
706
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
653
707
  const uint scale = data_a[ib].scales[iqs / 16];
654
708
  const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
655
709
  const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
656
- const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
657
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
710
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
711
+ const vec4 v = db * vec4(unpack8(grid));
658
712
 
659
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
660
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
713
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
714
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
715
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
716
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
661
717
  #elif defined(DATA_A_IQ4_XS)
662
718
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
663
719
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -162,17 +162,32 @@ void main() {
162
162
  _ne1 = 0;
163
163
  uint num_elements = p.nei1 * p.nei0;
164
164
 
165
- for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
165
+ uint ids[16];
166
+ uint iter = 0;
167
+
168
+ for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
169
+ // prefetch up to 16 elements
170
+ if (iter == 0) {
171
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
172
+ uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
173
+ bool in_range = i < num_elements;
174
+ uint ii1 = i / p.nei0;
175
+ uint ii0 = i % p.nei0;
176
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
177
+ }
178
+ }
179
+ uint i = j + gl_SubgroupInvocationID;
166
180
  bool in_range = i < num_elements;
167
- uint ii0 = i % p.nei0;
168
181
  uint ii1 = i / p.nei0;
169
- uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
182
+ uint ii0 = i % p.nei0;
183
+ uint id = ids[iter++];
170
184
  uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
171
185
  uint idx = subgroupBallotExclusiveBitCount(ballot);
172
186
  if (in_range && id == expert_idx) {
173
187
  row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
174
188
  }
175
189
  _ne1 += subgroupBallotBitCount(ballot);
190
+ iter &= 15;
176
191
  }
177
192
  _ne1_sh = _ne1;
178
193
  }
@@ -414,17 +429,31 @@ void main() {
414
429
  fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
415
430
  }
416
431
 
417
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
432
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
433
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
434
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
419
435
 
420
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
436
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
421
437
  #ifdef MUL_MAT_ID
422
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
438
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
423
439
  #else
424
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
440
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
425
441
  #endif
426
442
 
427
- sum = coopMatMulAdd(mat_a, mat_b, sum);
443
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
444
+ } else {
445
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
446
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
447
+
448
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
449
+ #ifdef MUL_MAT_ID
450
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
451
+ #else
452
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
453
+ #endif
454
+
455
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
456
+ }
428
457
  }
429
458
 
430
459
  // Convert from ACC_TYPE to D_TYPE
@@ -0,0 +1,9 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
+ float op(float a, float b) {
6
+ return max(a, 0.0f) * b;
7
+ }
8
+
9
+ #include "glu_main.comp"
@@ -1,11 +1,13 @@
1
1
  #version 450
2
2
 
3
- #include "generic_unary_head.comp"
3
+ #include "generic_binary_head.comp"
4
4
  #include "types.comp"
5
5
 
6
6
  #extension GL_EXT_control_flow_attributes : enable
7
7
  #define BLOCK_SIZE 512
8
8
 
9
+ layout (constant_id = 1) const bool do_multiply = false;
10
+
9
11
  layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
10
12
 
11
13
  shared FLOAT_TYPE sum[BLOCK_SIZE];
@@ -25,6 +27,7 @@ void main() {
25
27
  const uint stride_sample = p.nb03;
26
28
 
27
29
  uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
30
+ uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
28
31
  uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
29
32
 
30
33
  sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
@@ -46,7 +49,19 @@ void main() {
46
49
  const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
47
50
  const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
48
51
 
49
- [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
50
- data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
52
+ if (do_multiply) {
53
+ if (ncols > p.ne10) {
54
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
55
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
56
+ }
57
+ } else {
58
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
59
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
60
+ }
61
+ }
62
+ } else {
63
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
64
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
65
+ }
51
66
  }
52
67
  }
@@ -0,0 +1,46 @@
1
+ #version 450
2
+
3
+ #include "types.comp"
4
+ #include "generic_unary_head.comp"
5
+
6
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7
+
8
+ uint wrap_idx(int i, uint ne) {
9
+ if (i < 0) {
10
+ return i + ne;
11
+ } else if (i >= ne) {
12
+ return i - ne;
13
+ }
14
+ return i;
15
+ }
16
+
17
+ void main() {
18
+ const uint idx = get_idx();
19
+ if (idx >= p.ne) {
20
+ return;
21
+ }
22
+
23
+ const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
24
+ const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10;
25
+ const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L);
26
+ const uint i2_offset = i2*p.ne11*p.ne10;
27
+ const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L);
28
+ const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
29
+
30
+ const uint p1 = floatBitsToUint(p.param1);
31
+ const uint p2 = floatBitsToUint(p.param2);
32
+ const int s0 = int(p1 >> 16) - 0x8000;
33
+ const int s1 = int(p1 & 0xFFFF) - 0x8000;
34
+ const int s2 = int(p2 >> 16) - 0x8000;
35
+ const int s3 = int(p2 & 0xFFFF) - 0x8000;
36
+
37
+ const uint i00 = wrap_idx(int(i0) - s0, p.ne10);
38
+ const uint i01 = wrap_idx(int(i1) - s1, p.ne11);
39
+ const uint i02 = wrap_idx(int(i2) - s2, p.ne12);
40
+ const uint i03 = wrap_idx(int(i3) - s3, p.ne13);
41
+
42
+ const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
43
+ const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10;
44
+
45
+ data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
46
+ }
@@ -1,11 +1,8 @@
1
1
  #include "types.comp"
2
2
 
3
3
  #extension GL_EXT_shader_16bit_storage : require
4
- #extension GL_EXT_spirv_intrinsics: enable
5
4
 
6
- #if RTE16
7
- spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
8
- #endif
5
+ #include "rte.comp"
9
6
 
10
7
  layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
11
8
 
@@ -14,21 +14,19 @@ void main() {
14
14
 
15
15
  const uint row_dst = gl_GlobalInvocationID.x;
16
16
 
17
- if (i0 >= p.n_dims) {
18
- const uint i = row_dst*ne0 + i0;
19
-
20
- data_d[i + 0] = data_a[i + 0];
21
- data_d[i + 1] = data_a[i + 1];
22
-
23
- return;
24
- }
25
-
26
17
  const uint row_x = row_dst % ne1;
27
18
  const uint channel_x = row_dst / ne1;
28
19
 
29
20
  const uint idst = row_dst*ne0 + i0/2;
30
21
  const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
31
22
 
23
+ if (i0 >= p.n_dims) {
24
+ data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
25
+ data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
26
+
27
+ return;
28
+ }
29
+
32
30
  const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
33
31
  const int sec_w = p.sections[1] + p.sections[0];
34
32
  const uint sector = (i0 / 2) % sect_dims;
@@ -13,21 +13,19 @@ void main() {
13
13
 
14
14
  const uint row_dst = gl_GlobalInvocationID.x;
15
15
 
16
- if (i0 >= p.n_dims) {
17
- const uint i = row_dst*ne0 + i0;
18
-
19
- data_d[i + 0] = data_a[i + 0];
20
- data_d[i + 1] = data_a[i + 1];
21
-
22
- return;
23
- }
24
-
25
16
  const uint row_x = row_dst % ne1;
26
17
  const uint channel_x = row_dst / ne1;
27
18
 
28
19
  const uint idst = row_dst*ne0 + i0/2;
29
20
  const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
30
21
 
22
+ if (i0 >= p.n_dims) {
23
+ data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
24
+ data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
25
+
26
+ return;
27
+ }
28
+
31
29
  const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
32
30
 
33
31
  const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
@@ -13,21 +13,19 @@ void main() {
13
13
 
14
14
  const uint row_dst = gl_GlobalInvocationID.x;
15
15
 
16
- if (i0 >= p.n_dims) {
17
- const uint i = row_dst*ne0 + i0;
18
-
19
- data_d[i + 0] = data_a[i + 0];
20
- data_d[i + 1] = data_a[i + 1];
21
-
22
- return;
23
- }
24
-
25
16
  const uint row_x = row_dst % ne1;
26
17
  const uint channel_x = row_dst / ne1;
27
18
 
28
19
  const uint idst = row_dst*ne0 + i0;
29
20
  const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
30
21
 
22
+ if (i0 >= p.n_dims) {
23
+ data_d[idst + 0] = data_a[ix + 0];
24
+ data_d[idst + 1] = data_a[ix + 1];
25
+
26
+ return;
27
+ }
28
+
31
29
  const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
32
30
 
33
31
  const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
@@ -0,0 +1,5 @@
1
+
2
+ #if RTE16
3
+ #extension GL_EXT_spirv_intrinsics : enable
4
+ spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
5
+ #endif // RTE16
@@ -18,7 +18,7 @@ void main() {
18
18
  continue;
19
19
  }
20
20
 
21
- data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
21
+ data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
22
22
  idx += num_threads;
23
23
  }
24
24
  }