@novastera-oss/llamarn 0.2.7 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  10. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  11. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  13. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  15. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  18. package/cpp/LlamaCppModel.cpp +56 -22
  19. package/cpp/build-info.cpp +2 -2
  20. package/cpp/llama.cpp/CMakeLists.txt +1 -2
  21. package/cpp/llama.cpp/README.md +4 -5
  22. package/cpp/llama.cpp/build-xcframework.sh +1 -1
  23. package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
  24. package/cpp/llama.cpp/common/arg.cpp +24 -0
  25. package/cpp/llama.cpp/common/chat.cpp +37 -20
  26. package/cpp/llama.cpp/common/chat.h +2 -0
  27. package/cpp/llama.cpp/common/common.cpp +3 -0
  28. package/cpp/llama.cpp/common/common.h +5 -0
  29. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  30. package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
  31. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
  32. package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
  33. package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
  34. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  35. package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
  36. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
  37. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
  38. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
  39. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
  40. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  41. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  48. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  49. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  50. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  51. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  52. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  53. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  54. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  55. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  56. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  57. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
  58. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  59. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  60. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
  61. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
  62. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  63. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  64. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
  65. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
  66. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
  67. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
  68. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  69. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  70. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  71. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  72. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
  73. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
  74. package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  75. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
  76. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
  77. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
  78. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
  79. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
  80. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
  81. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
  82. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
  84. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  85. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  86. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
  87. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  88. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  89. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
  90. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
  91. package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
  92. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  93. package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
  94. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  95. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  96. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
  99. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
  100. package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
  101. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
  102. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
  103. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
  104. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
  105. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  106. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
  108. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  109. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
  110. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  111. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  112. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
  114. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
  115. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
  116. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
  117. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  118. package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
  119. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  121. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  123. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  124. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  125. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  126. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  127. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
  129. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  130. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  131. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
  132. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  133. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  134. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  135. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  136. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  137. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
  138. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
  139. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  140. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  141. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  142. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  144. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  145. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  146. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
  147. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
  149. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
  150. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
  151. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
  152. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
  153. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
  154. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
  161. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
  162. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  164. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  166. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  167. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  168. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  169. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
  170. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  172. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
  173. package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
  174. package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
  175. package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
  176. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
  177. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
  178. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
  179. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  180. package/cpp/llama.cpp/include/llama.h +8 -43
  181. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  182. package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
  183. package/cpp/llama.cpp/src/llama-arch.h +36 -1
  184. package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
  185. package/cpp/llama.cpp/src/llama-batch.h +105 -70
  186. package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
  187. package/cpp/llama.cpp/src/llama-chat.h +1 -0
  188. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  189. package/cpp/llama.cpp/src/llama-context.h +13 -13
  190. package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
  191. package/cpp/llama.cpp/src/llama-graph.h +78 -79
  192. package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
  193. package/cpp/llama.cpp/src/llama-hparams.h +11 -0
  194. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
  195. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
  196. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
  197. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
  198. package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
  199. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
  200. package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
  201. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
  202. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  203. package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
  204. package/cpp/llama.cpp/src/llama-memory.h +21 -22
  205. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  206. package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
  207. package/cpp/llama.cpp/src/llama-model.h +40 -0
  208. package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
  209. package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
  210. package/cpp/llama.cpp/src/llama-vocab.h +42 -0
  211. package/cpp/rn-utils.h +3 -0
  212. package/ios/include/chat.h +2 -0
  213. package/ios/include/common.h +5 -0
  214. package/ios/include/llama.h +8 -43
  215. package/ios/libs/llama.xcframework/Info.plist +19 -19
  216. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  217. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  218. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  219. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  220. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
  221. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
  222. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  223. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  224. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  225. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  226. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  227. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  228. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  229. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  230. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  231. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  232. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  233. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
  234. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
  235. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  236. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
  237. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
  238. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
  239. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  240. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
  241. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
  242. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  243. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
  244. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  245. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
  246. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
  247. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  248. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  249. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  250. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
  251. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  252. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  253. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
  254. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
  255. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  256. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  257. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
  258. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
  259. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  260. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  261. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  262. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  263. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  264. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  265. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
  266. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
  267. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  268. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
  269. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
  270. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  271. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  272. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
  273. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
  274. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
  275. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  276. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
  277. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
  278. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  279. package/package.json +1 -1
  280. package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
  281. package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  282. package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  283. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  284. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  285. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  286. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  287. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  288. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  289. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  290. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  291. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  292. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  293. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  294. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  295. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  296. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  297. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  298. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  299. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  300. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  301. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  302. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  303. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  304. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  305. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  306. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  307. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  308. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  309. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  310. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  311. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  312. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  313. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  314. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  315. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  316. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  317. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  318. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  319. package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
@@ -18,6 +18,7 @@
18
18
  #extension GL_KHR_cooperative_matrix : enable
19
19
  #extension GL_KHR_memory_scope_semantics : enable
20
20
  #extension GL_KHR_shader_subgroup_basic : enable
21
+ #extension GL_KHR_shader_subgroup_ballot : enable
21
22
  #endif
22
23
 
23
24
  #ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
104
105
 
105
106
  #ifdef MUL_MAT_ID
106
107
  shared u16vec2 row_ids[4096];
108
+ uint _ne1;
109
+ #ifdef COOPMAT
110
+ shared uint _ne1_sh;
111
+ #endif
107
112
  #endif // MUL_MAT_ID
108
113
 
109
114
  #define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
172
177
  const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
173
178
 
174
179
  #ifdef MUL_MAT_ID
175
- uint _ne1 = 0;
180
+ #ifdef COOPMAT
181
+ // Spread the search across all elements in the first subgroup
182
+ if (gl_SubgroupID == 0) {
183
+ _ne1 = 0;
184
+ uint num_elements = p.nei1 * p.nei0;
185
+
186
+ uint ids[16];
187
+ uint iter = 0;
188
+
189
+ for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
190
+ // prefetch up to 16 elements
191
+ if (iter == 0) {
192
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
193
+ uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
194
+ bool in_range = i < num_elements;
195
+ uint ii1 = i / p.nei0;
196
+ uint ii0 = i % p.nei0;
197
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
198
+ }
199
+ }
200
+ uint i = j + gl_SubgroupInvocationID;
201
+ bool in_range = i < num_elements;
202
+ uint ii1 = i / p.nei0;
203
+ uint ii0 = i % p.nei0;
204
+ uint id = ids[iter++];
205
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
206
+ uint idx = subgroupBallotExclusiveBitCount(ballot);
207
+ if (in_range && id == expert_idx) {
208
+ row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
209
+ }
210
+ _ne1 += subgroupBallotBitCount(ballot);
211
+ iter &= 15;
212
+ }
213
+ _ne1_sh = _ne1;
214
+ }
215
+
216
+ barrier();
217
+
218
+ _ne1 = _ne1_sh;
219
+ #else
220
+ _ne1 = 0;
176
221
  for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
177
222
  for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
178
223
  if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
183
228
  }
184
229
 
185
230
  barrier();
231
+ #endif
186
232
 
187
233
  // Workgroup has no work
188
234
  if (ic * BN >= _ne1) return;
@@ -500,10 +546,9 @@ void main() {
500
546
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
501
547
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
502
548
 
503
- const uint ib = idx / 128; // 2 values per idx
504
- const uint ib32 = (idx % 128) / 16; // 0..7
505
- const uint ib8 = (idx % 128) / 4;
506
- const int i8 = 2 * int(idx % 4);
549
+ const uint ib = idx / 32; // 8 values per idx
550
+ const uint ib32 = (idx % 32) / 4; // 0..7
551
+ const uint ib8 = idx % 32;
507
552
 
508
553
  const float d = float(data_a[ib].d);
509
554
  const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +557,16 @@ void main() {
512
557
  const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
513
558
  const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
514
559
 
515
- const ivec2 gvec = ivec2(
516
- bitfieldExtract(grid, 2 * (i8), 2),
517
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
518
- );
519
- const vec2 v = dl * (vec2(gvec) + delta);
520
-
521
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
522
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
560
+ [[unroll]] for (int k = 0; k < 8; ++k) {
561
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
562
+ }
523
563
  #elif defined(DATA_A_IQ1_M)
524
564
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
525
565
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
526
566
 
527
- const uint ib = idx / 128; // 2 values per idx
528
- const uint ib8 = (idx % 128) / 4;
567
+ const uint ib = idx / 32; // 8 values per idx
568
+ const uint ib8 = idx % 32;
529
569
  const uint ib16 = ib8 / 2;
530
- const int i8 = 2 * int(idx % 4);
531
570
 
532
571
  const uint16_t[4] scales = data_a[ib].scales;
533
572
  const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +577,17 @@ void main() {
538
577
  const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
539
578
  const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
540
579
  const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
541
- const ivec2 gvec = ivec2(
542
- bitfieldExtract(grid, 2 * (i8), 2),
543
- bitfieldExtract(grid, 2 * (i8 + 1), 2)
544
- );
545
- const vec2 v = dl * (vec2(gvec) + delta);
546
580
 
547
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
548
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
581
+ [[unroll]] for (int k = 0; k < 8; ++k) {
582
+ buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
583
+ }
549
584
  #elif defined(DATA_A_IQ2_XXS)
550
585
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
551
586
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
552
587
 
553
- const uint ib = idx / 128; // 2 values per idx
554
- const uint ib32 = (idx % 128) / 16; // 0..7
555
- const uint ib8 = (idx / 4) % 4;
588
+ const uint ib = idx / 32; // 8 values per idx
589
+ const uint ib32 = (idx % 32) / 4; // 0..7
590
+ const uint ib8 = idx % 4;
556
591
 
557
592
  const float d = float(data_a[ib].d);
558
593
  const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +597,81 @@ void main() {
562
597
  data_a[ib].qs[8*ib32 + 6],
563
598
  data_a[ib].qs[8*ib32 + 7]
564
599
  ));
565
- const float db = d * 0.25 * (0.5 + (signs >> 28));
600
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
566
601
  const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
567
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
568
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
569
- const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
570
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
571
-
572
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
573
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
602
+ const uint sign = sign7 | (bitCount(sign7) << 7);
603
+ const uvec2 grid = iq2xxs_grid[qs];
604
+ const vec4 grid0 = vec4(unpack8(grid.x));
605
+ const vec4 grid1 = vec4(unpack8(grid.y));
606
+
607
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
608
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
609
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
610
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
611
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
612
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
613
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
614
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
574
615
  #elif defined(DATA_A_IQ2_XS)
575
616
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
576
617
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
577
618
 
578
- const uint ib = idx / 128; // 2 values per idx
579
- const uint ib32 = (idx % 128) / 16; // 0..7
580
- const uint ib8 = (idx / 4) % 4; // 0..3
619
+ const uint ib = idx / 32; // 8 values per idx
620
+ const uint ib32 = (idx % 32) / 4; // 0..7
621
+ const uint ib8 = idx % 4; // 0..3
581
622
 
582
623
  const float d = float(data_a[ib].d);
583
624
  const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
584
- const float db = d * 0.25 * (0.5 + scale);
625
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
585
626
  const uint qs = data_a[ib].qs[4 * ib32 + ib8];
586
627
  const uint sign7 = qs >> 9;
587
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
588
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
589
- const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
590
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
591
-
592
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
593
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
628
+ const uint sign = sign7 | (bitCount(sign7) << 7);
629
+ const uvec2 grid = iq2xs_grid[qs & 511];
630
+ const vec4 grid0 = vec4(unpack8(grid.x));
631
+ const vec4 grid1 = vec4(unpack8(grid.y));
632
+
633
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
634
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
635
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
636
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
637
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
638
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
639
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
640
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
594
641
  #elif defined(DATA_A_IQ2_S)
595
642
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
596
643
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
597
644
 
598
- const uint ib = idx / 128; // 2 values per idx
599
- const uint ib8 = (idx % 128) / 4; // 0..31
600
- const uint ib32 = ib8 / 4; // 0..7
645
+ const uint ib = idx / 32; // 8 values per idx
646
+ const uint ib8 = idx % 32; // 0..31
647
+ const uint ib32 = ib8 / 4; // 0..7
601
648
 
602
649
  const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
603
650
  const uint qs = data_a[ib].qs[ib8];
604
651
  const uint qh = data_a[ib].qh[ib32];
605
652
  const uint qhshift = 2 * (ib8 % 4);
606
- const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
653
+ const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
607
654
 
608
655
  const float d = float(data_a[ib].d);
609
- const float db = d * 0.25 * (0.5 + scale);
610
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
611
- const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
612
- const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
613
-
614
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
615
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
656
+ const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
657
+ const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
658
+ const vec4 grid0 = vec4(unpack8(grid.x));
659
+ const vec4 grid1 = vec4(unpack8(grid.y));
660
+
661
+ buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
662
+ buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
663
+ buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
664
+ buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
665
+ buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
666
+ buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
667
+ buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
668
+ buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
616
669
  #elif defined(DATA_A_IQ3_XXS)
617
670
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
618
671
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
619
672
 
620
- const uint ib = idx / 128; // 2 values per idx
621
- const uint iqs = (idx % 128) / 2; // 0..63
673
+ const uint ib = idx / 64; // 4 values per idx
674
+ const uint iqs = idx % 64; // 0..63
622
675
  const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
623
676
 
624
677
  const float d = float(data_a[ib].d);
@@ -631,33 +684,36 @@ void main() {
631
684
  ));
632
685
  const float db = d * 0.5 * (0.5 + (signs >> 28));
633
686
  const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
634
- const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
635
- const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
636
- const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
637
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
638
-
639
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
640
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
687
+ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
688
+ const uint grid = iq3xxs_grid[qs];
689
+ const vec4 v = db * vec4(unpack8(grid));
690
+
691
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
692
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
693
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
694
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
641
695
  #elif defined(DATA_A_IQ3_S)
642
696
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
643
697
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
644
698
 
645
- const uint ib = idx / 128; // 2 values per idx
646
- const uint iqs = (idx % 128) / 2; // 0..63
699
+ const uint ib = idx / 64; // 4 values per idx
700
+ const uint iqs = idx % 64; // 0..63
647
701
  const uint iqh = iqs / 8;
648
702
 
649
703
  const float d = float(data_a[ib].d);
650
704
  const uint qs = data_a[ib].qs[iqs];
651
705
  const uint qh = data_a[ib].qh[iqh];
652
- const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
706
+ const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
653
707
  const uint scale = data_a[ib].scales[iqs / 16];
654
708
  const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
655
709
  const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
656
- const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
657
- const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
710
+ const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
711
+ const vec4 v = db * vec4(unpack8(grid));
658
712
 
659
- buf_a[buf_idx ] = FLOAT_TYPE(v.x);
660
- buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
713
+ buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
714
+ buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
715
+ buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
716
+ buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
661
717
  #elif defined(DATA_A_IQ4_XS)
662
718
  const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
663
719
  const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -162,17 +162,32 @@ void main() {
162
162
  _ne1 = 0;
163
163
  uint num_elements = p.nei1 * p.nei0;
164
164
 
165
- for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
165
+ uint ids[16];
166
+ uint iter = 0;
167
+
168
+ for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
169
+ // prefetch up to 16 elements
170
+ if (iter == 0) {
171
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
172
+ uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
173
+ bool in_range = i < num_elements;
174
+ uint ii1 = i / p.nei0;
175
+ uint ii0 = i % p.nei0;
176
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
177
+ }
178
+ }
179
+ uint i = j + gl_SubgroupInvocationID;
166
180
  bool in_range = i < num_elements;
167
- uint ii0 = i % p.nei0;
168
181
  uint ii1 = i / p.nei0;
169
- uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
182
+ uint ii0 = i % p.nei0;
183
+ uint id = ids[iter++];
170
184
  uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
171
185
  uint idx = subgroupBallotExclusiveBitCount(ballot);
172
186
  if (in_range && id == expert_idx) {
173
187
  row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
174
188
  }
175
189
  _ne1 += subgroupBallotBitCount(ballot);
190
+ iter &= 15;
176
191
  }
177
192
  _ne1_sh = _ne1;
178
193
  }
@@ -414,17 +429,31 @@ void main() {
414
429
  fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
415
430
  }
416
431
 
417
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
418
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
432
+ if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
433
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
434
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
419
435
 
420
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
436
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
421
437
  #ifdef MUL_MAT_ID
422
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
438
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
423
439
  #else
424
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
440
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
425
441
  #endif
426
442
 
427
- sum = coopMatMulAdd(mat_a, mat_b, sum);
443
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
444
+ } else {
445
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
446
+ coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
447
+
448
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
449
+ #ifdef MUL_MAT_ID
450
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
451
+ #else
452
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
453
+ #endif
454
+
455
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
456
+ }
428
457
  }
429
458
 
430
459
  // Convert from ACC_TYPE to D_TYPE
@@ -0,0 +1,9 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
// ReLU-gated linear unit (ReGLU-style) element operation.
// `a` goes through ReLU (clamped at zero from below) and the result is scaled by `b`.
// NOTE(review): presumably invoked per element by glu_main.comp — confirm operand order there.
float op(float a, float b) {
    const float gate = max(a, 0.0f);
    return gate * b;
}
8
+
9
+ #include "glu_main.comp"
@@ -1,11 +1,13 @@
1
1
  #version 450
2
2
 
3
- #include "generic_unary_head.comp"
3
+ #include "generic_binary_head.comp"
4
4
  #include "types.comp"
5
5
 
6
6
  #extension GL_EXT_control_flow_attributes : enable
7
7
  #define BLOCK_SIZE 512
8
8
 
9
+ layout (constant_id = 1) const bool do_multiply = false;
10
+
9
11
  layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
10
12
 
11
13
  shared FLOAT_TYPE sum[BLOCK_SIZE];
@@ -25,6 +27,7 @@ void main() {
25
27
  const uint stride_sample = p.nb03;
26
28
 
27
29
  uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
30
+ uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
28
31
  uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
29
32
 
30
33
  sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
@@ -46,7 +49,13 @@ void main() {
46
49
  const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
47
50
  const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
48
51
 
49
- [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
50
- data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
52
+ if (do_multiply) {
53
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
54
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
55
+ }
56
+ } else {
57
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
58
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
59
+ }
51
60
  }
52
61
  }
@@ -0,0 +1,46 @@
1
+ #version 450
2
+
3
+ #include "types.comp"
4
+ #include "generic_unary_head.comp"
5
+
6
+ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7
+
8
// Fold a coordinate that was displaced by a roll back into [0, ne).
// Only one period is added or subtracted, so the result is in range only when
// -ne <= i < 2*ne. This presumably holds because the roll shift magnitude is < ne
// (NOTE(review): confirm against the host-side ggml_roll checks).
uint wrap_idx(int i, uint ne) {
    const uint u = uint(i);
    if (i >= 0 && u < ne) {
        return u;   // already inside the extent
    }
    // Underflow wraps forward by one period; overflow wraps back by one period.
    return (i < 0) ? u + ne : u - ne;
}
16
+
17
// Roll (circular shift) kernel. Each invocation handles one destination element:
// it splits the flat index into 4-D coordinates, shifts each coordinate
// backwards by the packed roll amount (wrapping around the extent), and copies
// the source element found there.
void main() {
    const uint idx = get_idx();
    if (idx >= p.ne) {
        return;   // invocation past the last element
    }

    // Split idx into (i0, i1, i2, i3), with i0 varying fastest. fastdiv takes
    // precomputed multiplier/shift pairs (ne1_*mp / ne1_*L); presumably these
    // encode exact division by the destination extents — confirm in the head file.
    const uint ne_01  = p.ne11 * p.ne10;
    const uint ne_012 = p.ne12 * ne_01;

    const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
    const uint r3 = idx - i3 * ne_012;
    const uint i2 = fastdiv(r3, p.ne1_01mp, p.ne1_01L);
    const uint r2 = r3 - i2 * ne_01;
    const uint i1 = fastdiv(r2, p.ne1_0mp, p.ne1_0L);
    const uint i0 = r2 - i1 * p.ne10;

    // Each float parameter carries two 16-bit shifts bit-for-bit, each biased by 0x8000.
    const uint packed01 = floatBitsToUint(p.param1);
    const uint packed23 = floatBitsToUint(p.param2);
    const int s0 = int(packed01 >> 16) - 0x8000;
    const int s1 = int(packed01 & 0xFFFF) - 0x8000;
    const int s2 = int(packed23 >> 16) - 0x8000;
    const int s3 = int(packed23 & 0xFFFF) - 0x8000;

    // Destination coordinate i reads the source at i - s, wrapped into the extent.
    // NOTE(review): destination extents (ne10..ne13) are used for the source
    // coordinates; this assumes src and dst share a shape — confirm.
    const uint src0 = wrap_idx(int(i0) - s0, p.ne10);
    const uint src1 = wrap_idx(int(i1) - s1, p.ne11);
    const uint src2 = wrap_idx(int(i2) - s2, p.ne12);
    const uint src3 = wrap_idx(int(i3) - s3, p.ne13);

    // Source uses the nb0x strides and destination uses the nb1x strides.
    const uint a_idx = src3 * p.nb03 + src2 * p.nb02 + src1 * p.nb01 + src0 * p.nb00;
    const uint d_idx = i3 * p.nb13 + i2 * p.nb12 + i1 * p.nb11 + i0 * p.nb10;

    data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]);
}
@@ -14,21 +14,19 @@ void main() {
14
14
 
15
15
  const uint row_dst = gl_GlobalInvocationID.x;
16
16
 
17
- if (i0 >= p.n_dims) {
18
- const uint i = row_dst*ne0 + i0;
19
-
20
- data_d[i + 0] = data_a[i + 0];
21
- data_d[i + 1] = data_a[i + 1];
22
-
23
- return;
24
- }
25
-
26
17
  const uint row_x = row_dst % ne1;
27
18
  const uint channel_x = row_dst / ne1;
28
19
 
29
20
  const uint idst = row_dst*ne0 + i0/2;
30
21
  const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
31
22
 
23
+ if (i0 >= p.n_dims) {
24
+ data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
25
+ data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
26
+
27
+ return;
28
+ }
29
+
32
30
  const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3];
33
31
  const int sec_w = p.sections[1] + p.sections[0];
34
32
  const uint sector = (i0 / 2) % sect_dims;
@@ -13,21 +13,19 @@ void main() {
13
13
 
14
14
  const uint row_dst = gl_GlobalInvocationID.x;
15
15
 
16
- if (i0 >= p.n_dims) {
17
- const uint i = row_dst*ne0 + i0;
18
-
19
- data_d[i + 0] = data_a[i + 0];
20
- data_d[i + 1] = data_a[i + 1];
21
-
22
- return;
23
- }
24
-
25
16
  const uint row_x = row_dst % ne1;
26
17
  const uint channel_x = row_dst / ne1;
27
18
 
28
19
  const uint idst = row_dst*ne0 + i0/2;
29
20
  const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2;
30
21
 
22
+ if (i0 >= p.n_dims) {
23
+ data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0];
24
+ data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1];
25
+
26
+ return;
27
+ }
28
+
31
29
  const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
32
30
 
33
31
  const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
@@ -13,21 +13,19 @@ void main() {
13
13
 
14
14
  const uint row_dst = gl_GlobalInvocationID.x;
15
15
 
16
- if (i0 >= p.n_dims) {
17
- const uint i = row_dst*ne0 + i0;
18
-
19
- data_d[i + 0] = data_a[i + 0];
20
- data_d[i + 1] = data_a[i + 1];
21
-
22
- return;
23
- }
24
-
25
16
  const uint row_x = row_dst % ne1;
26
17
  const uint channel_x = row_dst / ne1;
27
18
 
28
19
  const uint idst = row_dst*ne0 + i0;
29
20
  const uint ix = channel_x*p.s2 + row_x*p.s1 + i0;
30
21
 
22
+ if (i0 >= p.n_dims) {
23
+ data_d[idst + 0] = data_a[ix + 0];
24
+ data_d[idst + 1] = data_a[ix + 1];
25
+
26
+ return;
27
+ }
28
+
31
29
  const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f);
32
30
 
33
31
  const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;
@@ -18,7 +18,7 @@ void main() {
18
18
  continue;
19
19
  }
20
20
 
21
- data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
21
+ data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
22
22
  idx += num_threads;
23
23
  }
24
24
  }
@@ -6,6 +6,14 @@ layout (push_constant) uniform parameter
6
6
  {
7
7
  uint KX;
8
8
  uint KY;
9
+ uint ne00;
10
+ uint ne01;
11
+ uint ne02;
12
+ uint ne12;
13
+ uint ne13;
14
+ uint nb11;
15
+ uint nb12;
16
+ uint nb13;
9
17
  float scale;
10
18
  float max_bias;
11
19
  float m0;
@@ -31,7 +39,15 @@ shared FLOAT_TYPE vals[BLOCK_SIZE];
31
39
  void soft_max(uint num_iters) {
32
40
  const uint tid = gl_LocalInvocationID.x;
33
41
  const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
34
- const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0;
42
+
43
+ const uint32_t i03 = rowx / (p.ne01 * p.ne02);
44
+ const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
45
+ const uint32_t i01 = rowx % p.ne01;
46
+
47
+ uint rowy_start = 0;
48
+ if (p.KY > 0) {
49
+ rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
50
+ }
35
51
 
36
52
  if (rowx >= p.nrows_x) {
37
53
  return;
@@ -41,7 +57,7 @@ void soft_max(uint num_iters) {
41
57
 
42
58
  // ALiBi
43
59
  if (p.max_bias > 0.0f) {
44
- const uint h = rowx/p.KY; // head index
60
+ const uint h = (rowx / p.ne01) % p.ne02; // head index
45
61
 
46
62
  const float base = h < p.n_head_log2 ? p.m0 : p.m1;
47
63
  const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
@@ -67,7 +83,7 @@ void soft_max(uint num_iters) {
67
83
 
68
84
  FLOAT_TYPE b = FLOAT_TYPE(0);
69
85
  if (p.KY > 0 && col < p.KX) {
70
- b = data_b[rowy * p.KX + col];
86
+ b = data_b[rowy_start + col];
71
87
  }
72
88
 
73
89
  FLOAT_TYPE v = a * p.scale + slope * b;
@@ -111,7 +127,7 @@ void soft_max(uint num_iters) {
111
127
  if (idx < DATA_CACHE_SIZE) {
112
128
  val = exp(data_cache[idx] - max_val);
113
129
  } else {
114
- val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
130
+ val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val);
115
131
  }
116
132
  sum += val;
117
133
  if (idx < DATA_CACHE_SIZE) {
@@ -0,0 +1,9 @@
1
+ #version 450
2
+
3
+ #include "glu_head.comp"
4
+
5
// SiLU-gated linear unit (SwiGLU-style) element operation.
// SiLU(a) = a * sigmoid(a) = a / (1 + exp(-a)); the result is scaled by `b`.
// NOTE(review): presumably invoked per element by glu_main.comp — confirm operand order there.
float op(float a, float b) {
    const float silu = a / (1.0f + exp(-a));
    return silu * b;
}
8
+
9
+ #include "glu_main.comp"