@novastera-oss/llamarn 0.4.1 → 0.4.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (976)
  1. package/RNLlamaCpp.podspec +3 -0
  2. package/android/CMakeLists.txt +2 -0
  3. package/android/src/main/cpp/include/llama.h +44 -21
  4. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  6. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  7. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  8. package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
  9. package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
  10. package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
  12. package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
  13. package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
  14. package/android/src/main/jniLibs/x86/libggml.so +0 -0
  15. package/android/src/main/jniLibs/x86/libllama.so +0 -0
  16. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  17. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  18. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  19. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  20. package/cpp/build-info.cpp +2 -2
  21. package/cpp/llama.cpp/CMakeLists.txt +12 -0
  22. package/cpp/llama.cpp/CODEOWNERS +116 -10
  23. package/cpp/llama.cpp/CONTRIBUTING.md +30 -3
  24. package/cpp/llama.cpp/README.md +13 -5
  25. package/cpp/llama.cpp/build-xcframework.sh +5 -0
  26. package/cpp/llama.cpp/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  27. package/cpp/llama.cpp/common/CMakeLists.txt +12 -2
  28. package/cpp/llama.cpp/common/arg.cpp +303 -795
  29. package/cpp/llama.cpp/common/arg.h +2 -3
  30. package/cpp/llama.cpp/common/chat-parser-xml-toolcall.cpp +861 -0
  31. package/cpp/llama.cpp/common/chat-parser-xml-toolcall.h +45 -0
  32. package/cpp/llama.cpp/common/chat-parser.cpp +156 -15
  33. package/cpp/llama.cpp/common/chat-parser.h +13 -0
  34. package/cpp/llama.cpp/common/chat.cpp +1147 -88
  35. package/cpp/llama.cpp/common/chat.h +16 -3
  36. package/cpp/llama.cpp/common/common.cpp +70 -15
  37. package/cpp/llama.cpp/common/common.h +57 -19
  38. package/cpp/llama.cpp/common/download.cpp +1072 -0
  39. package/cpp/llama.cpp/common/download.h +55 -0
  40. package/cpp/llama.cpp/common/http.h +73 -0
  41. package/cpp/llama.cpp/common/json-partial.cpp +70 -2
  42. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +61 -22
  43. package/cpp/llama.cpp/common/json-schema-to-grammar.h +2 -0
  44. package/cpp/llama.cpp/common/log.cpp +59 -2
  45. package/cpp/llama.cpp/common/log.h +12 -4
  46. package/cpp/llama.cpp/common/sampling.cpp +84 -8
  47. package/cpp/llama.cpp/common/sampling.h +3 -1
  48. package/cpp/llama.cpp/common/speculative.cpp +1 -1
  49. package/cpp/llama.cpp/convert_hf_to_gguf.py +1608 -233
  50. package/cpp/llama.cpp/convert_hf_to_gguf_update.py +6 -1
  51. package/cpp/llama.cpp/convert_lora_to_gguf.py +37 -5
  52. package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -28
  53. package/cpp/llama.cpp/ggml/include/ggml-backend.h +19 -1
  54. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +1 -1
  55. package/cpp/llama.cpp/ggml/include/ggml-hexagon.h +19 -0
  56. package/cpp/llama.cpp/ggml/include/ggml-metal.h +1 -6
  57. package/cpp/llama.cpp/ggml/include/ggml-rpc.h +7 -9
  58. package/cpp/llama.cpp/ggml/include/ggml-zdnn.h +2 -1
  59. package/cpp/llama.cpp/ggml/include/ggml.h +199 -6
  60. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +38 -0
  61. package/cpp/llama.cpp/ggml/src/ggml-alloc.c +299 -130
  62. package/cpp/llama.cpp/ggml/src/ggml-backend-impl.h +4 -4
  63. package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +21 -5
  64. package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +99 -2
  65. package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  66. package/cpp/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +1 -0
  67. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  68. package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +138 -47
  69. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +1584 -1773
  70. package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +201 -317
  71. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +146 -187
  72. package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +771 -713
  73. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +135 -77
  74. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  75. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  76. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +16 -17
  77. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +318 -145
  78. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  79. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +155 -60
  80. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +8 -8
  81. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -1
  82. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +14 -0
  83. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -9
  84. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +108 -64
  85. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +14 -4
  86. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +530 -87
  87. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +37 -45
  88. package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +349 -127
  89. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +947 -1218
  90. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +5 -4
  91. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +143 -29
  92. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +82 -76
  93. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  94. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  95. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  96. package/cpp/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  97. package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  98. package/cpp/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +7 -0
  99. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +233 -28
  100. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +326 -66
  101. package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +12 -3
  102. package/cpp/llama.cpp/ggml/src/ggml-cuda/argsort.cu +102 -6
  103. package/cpp/llama.cpp/ggml/src/ggml-cuda/binbcast.cu +110 -76
  104. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +167 -38
  105. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d.cu +6 -11
  106. package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +12 -0
  107. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  108. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +245 -151
  109. package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cuh +1 -5
  110. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +341 -289
  111. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  112. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh +1233 -0
  113. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  114. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +6 -6
  115. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  116. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +123 -220
  117. package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +41 -39
  118. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +715 -45
  119. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +150 -0
  120. package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cuh +1 -0
  121. package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +321 -24
  122. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cu +93 -351
  123. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmf.cuh +828 -1
  124. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cu +164 -0
  125. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmid.cuh +5 -0
  126. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +3 -166
  127. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1 -1
  128. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cu +371 -78
  129. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  130. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cu +279 -147
  131. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  132. package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +97 -85
  133. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad.cu +46 -23
  134. package/cpp/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu +63 -54
  135. package/cpp/llama.cpp/ggml/src/ggml-cuda/quantize.cu +12 -10
  136. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +192 -77
  137. package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cuh +2 -0
  138. package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +10 -9
  139. package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +137 -75
  140. package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cu +39 -0
  141. package/cpp/llama.cpp/ggml/src/ggml-cuda/set.cuh +7 -0
  142. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  143. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  144. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  145. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  146. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  147. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  148. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  149. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  150. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  151. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  152. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  153. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  154. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  155. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  156. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  157. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  158. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  159. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  160. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  161. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  162. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  163. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  164. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  165. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  166. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  167. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  168. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  169. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  170. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  171. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  172. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  173. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  174. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  175. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  176. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  177. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  178. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  179. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  180. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  181. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  182. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  183. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  184. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  185. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  186. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  187. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  188. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  189. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  190. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  191. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  192. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  193. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  194. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  195. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  196. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  197. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  198. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  199. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  200. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  201. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  202. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  203. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  204. package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu +336 -0
  205. package/cpp/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh +16 -0
  206. package/cpp/llama.cpp/ggml/src/ggml-cuda/tsembd.cu +3 -3
  207. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +105 -11
  208. package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +36 -0
  209. package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +87 -6
  210. package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +28 -12
  211. package/cpp/llama.cpp/ggml/src/ggml-hexagon/CMakeLists.txt +68 -0
  212. package/cpp/llama.cpp/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3807 -0
  213. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/CMakeLists.txt +40 -0
  214. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/act-ops.c +442 -0
  215. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  216. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  217. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ctx.h +40 -0
  218. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.c +69 -0
  219. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-dma.h +119 -0
  220. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-msg.h +156 -0
  221. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp-ops.h +64 -0
  222. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  223. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-exp.c +93 -0
  224. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-inverse.c +60 -0
  225. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  226. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.c +960 -0
  227. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/hvx-utils.h +1032 -0
  228. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/main.c +829 -0
  229. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/matmul-ops.c +2223 -0
  230. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  231. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/rope-ops.c +418 -0
  232. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  233. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/unary-ops.c +255 -0
  234. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  235. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  236. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.c +448 -0
  237. package/cpp/llama.cpp/ggml/src/ggml-hexagon/htp-utils.h +220 -0
  238. package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  239. package/cpp/llama.cpp/ggml/src/ggml-impl.h +110 -12
  240. package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +6 -5
  241. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  242. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  243. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  244. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m +599 -0
  245. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp +1662 -0
  246. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h +251 -0
  247. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m +1527 -0
  248. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +244 -39
  249. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp +3844 -0
  250. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h +90 -0
  251. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp +723 -0
  252. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +3453 -1907
  253. package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  254. package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +10 -0
  255. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1331 -109
  256. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/cvt.cl +126 -0
  257. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +31 -4
  258. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +35 -7
  259. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +31 -4
  260. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  261. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  262. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  263. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  264. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  265. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  266. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  267. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  268. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  269. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  270. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  271. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  272. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  273. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  274. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  275. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  276. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +123 -10
  277. package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  278. package/cpp/llama.cpp/ggml/src/ggml-quants.c +1 -0
  279. package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +341 -161
  280. package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  281. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  282. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +74 -15
  283. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +50 -30
  284. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +10 -4
  285. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +166 -99
  286. package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  287. package/cpp/llama.cpp/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  288. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +72 -94
  289. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  290. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +21 -31
  291. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +252 -316
  292. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +6 -2
  293. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +9 -6
  294. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +359 -142
  295. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  296. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  297. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +80 -60
  298. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  299. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +230 -55
  300. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.hpp +2 -0
  301. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.cpp +97 -0
  302. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad.hpp +24 -0
  303. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.cpp +72 -0
  304. package/cpp/llama.cpp/ggml/src/ggml-sycl/pad_reflect_1d.hpp +8 -0
  305. package/cpp/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  306. package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  307. package/cpp/llama.cpp/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  308. package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.cpp +122 -0
  309. package/cpp/llama.cpp/ggml/src/ggml-sycl/roll.hpp +20 -0
  310. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +50 -41
  311. package/cpp/llama.cpp/ggml/src/ggml-sycl/set.cpp +73 -0
  312. package/cpp/llama.cpp/ggml/src/ggml-sycl/set.hpp +5 -0
  313. package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +45 -36
  314. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +330 -165
  315. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +4 -0
  316. package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  317. package/cpp/llama.cpp/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  318. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  319. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +16 -12
  320. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  321. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4184 -2159
  322. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  323. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  324. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  325. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  326. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  327. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  328. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  329. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  330. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  331. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  332. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  333. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  334. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  335. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  336. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +53 -30
  337. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  338. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  339. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  340. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +13 -6
  341. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  342. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  343. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  344. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  345. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +138 -2
  346. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  347. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  348. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  349. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  350. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  351. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  352. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  353. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  354. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  355. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  356. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  357. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  358. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  359. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  360. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  361. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  362. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  363. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  364. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  365. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  366. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  367. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  368. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  369. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  370. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -2
  371. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  372. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +52 -14
  373. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +50 -12
  374. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +61 -12
  375. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +54 -12
  376. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +5 -1
  377. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  378. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  379. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  380. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  381. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  382. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  383. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  384. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +10 -2
  385. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  386. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  387. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  388. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  389. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  390. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  391. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +15 -7
  392. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  393. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  394. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  395. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  396. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  397. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +1 -1
  398. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +229 -0
  399. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +33 -0
  400. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +1 -1
  401. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +1 -1
  402. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +1 -1
  403. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +1 -1
  404. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +1 -1
  405. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +1 -1
  406. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +1 -1
  407. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  408. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  409. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +3 -5
  410. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +1 -1
  411. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +3 -5
  412. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +3 -5
  413. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +1 -1
  414. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  415. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +106 -634
  416. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +118 -9
  417. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +556 -0
  418. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +70 -0
  419. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +77 -214
  420. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +589 -0
  421. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  422. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  423. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  424. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  425. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  426. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  427. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +25 -4
  428. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  429. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +55 -5
  430. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  431. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  432. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  433. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  434. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +45 -3
  435. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  436. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  437. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  438. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +227 -0
  439. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  440. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +5 -52
  441. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +5 -35
  442. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +5 -35
  443. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +27 -0
  444. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +5 -41
  445. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  446. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  447. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  448. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  449. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  450. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  451. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  452. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  453. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  454. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  455. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  456. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  457. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +140 -0
  458. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  459. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  460. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +1 -1
  461. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  462. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  463. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  464. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  465. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +171 -0
  466. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  467. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +79 -29
  468. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -12
  469. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +471 -196
  470. package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +8 -0
  471. package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1690 -383
  472. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  473. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  474. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  475. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  476. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +57 -10
  477. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  478. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  479. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +25 -912
  480. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  481. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  482. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  483. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  484. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  485. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  486. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  487. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/{set_rows.wgsl → set_rows.tmpl.wgsl} +38 -8
  488. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  489. package/cpp/llama.cpp/ggml/src/ggml-zdnn/common.hpp +59 -0
  490. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn.cpp +96 -314
  491. package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  492. package/cpp/llama.cpp/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  493. package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.cpp +79 -0
  494. package/cpp/llama.cpp/ggml/src/ggml-zdnn/utils.hpp +19 -0
  495. package/cpp/llama.cpp/ggml/src/ggml.c +440 -17
  496. package/cpp/llama.cpp/ggml/src/gguf.cpp +104 -29
  497. package/cpp/llama.cpp/gguf-py/gguf/constants.py +363 -13
  498. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +64 -0
  499. package/cpp/llama.cpp/gguf-py/gguf/lazy.py +8 -3
  500. package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_convert_endian.py +6 -0
  501. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +156 -18
  502. package/cpp/llama.cpp/gguf-py/gguf/utility.py +80 -0
  503. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +4 -4
  504. package/cpp/llama.cpp/include/llama.h +44 -21
  505. package/cpp/llama.cpp/media/llama1-icon-transparent.png +0 -0
  506. package/cpp/llama.cpp/media/llama1-icon-transparent.svg +77 -0
  507. package/cpp/llama.cpp/media/llama1-icon.png +0 -0
  508. package/cpp/llama.cpp/media/llama1-icon.svg +87 -0
  509. package/cpp/llama.cpp/requirements/requirements-all.txt +2 -0
  510. package/cpp/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -3
  511. package/cpp/llama.cpp/requirements/requirements-convert_legacy_llama.txt +3 -1
  512. package/cpp/llama.cpp/requirements/requirements-tool_bench.txt +1 -1
  513. package/cpp/llama.cpp/src/CMakeLists.txt +101 -0
  514. package/cpp/llama.cpp/src/llama-adapter.cpp +33 -0
  515. package/cpp/llama.cpp/src/llama-adapter.h +3 -0
  516. package/cpp/llama.cpp/src/llama-arch.cpp +344 -14
  517. package/cpp/llama.cpp/src/llama-arch.h +50 -0
  518. package/cpp/llama.cpp/src/llama-batch.cpp +63 -31
  519. package/cpp/llama.cpp/src/llama-batch.h +13 -2
  520. package/cpp/llama.cpp/src/llama-chat.cpp +85 -3
  521. package/cpp/llama.cpp/src/llama-chat.h +4 -0
  522. package/cpp/llama.cpp/src/llama-context.cpp +300 -45
  523. package/cpp/llama.cpp/src/llama-context.h +16 -6
  524. package/cpp/llama.cpp/src/llama-cparams.h +2 -1
  525. package/cpp/llama.cpp/src/llama-grammar.cpp +17 -9
  526. package/cpp/llama.cpp/src/llama-graph.cpp +226 -64
  527. package/cpp/llama.cpp/src/llama-graph.h +27 -5
  528. package/cpp/llama.cpp/src/llama-hparams.cpp +53 -2
  529. package/cpp/llama.cpp/src/llama-hparams.h +48 -8
  530. package/cpp/llama.cpp/src/llama-impl.cpp +3 -3
  531. package/cpp/llama.cpp/src/llama-impl.h +2 -0
  532. package/cpp/llama.cpp/src/llama-kv-cache-iswa.cpp +13 -3
  533. package/cpp/llama.cpp/src/llama-kv-cache-iswa.h +2 -0
  534. package/cpp/llama.cpp/src/llama-kv-cache.cpp +120 -62
  535. package/cpp/llama.cpp/src/llama-kv-cache.h +13 -4
  536. package/cpp/llama.cpp/src/llama-kv-cells.h +44 -2
  537. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +19 -9
  538. package/cpp/llama.cpp/src/llama-memory-hybrid.h +2 -0
  539. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +38 -17
  540. package/cpp/llama.cpp/src/llama-memory-recurrent.h +5 -2
  541. package/cpp/llama.cpp/src/llama-memory.h +3 -0
  542. package/cpp/llama.cpp/src/llama-model-loader.cpp +2 -0
  543. package/cpp/llama.cpp/src/llama-model.cpp +1070 -12614
  544. package/cpp/llama.cpp/src/llama-model.h +40 -4
  545. package/cpp/llama.cpp/src/llama-quant.cpp +14 -6
  546. package/cpp/llama.cpp/src/llama-sampling.cpp +243 -136
  547. package/cpp/llama.cpp/src/llama-vocab.cpp +43 -3
  548. package/cpp/llama.cpp/src/llama-vocab.h +43 -39
  549. package/cpp/llama.cpp/src/llama.cpp +69 -10
  550. package/cpp/llama.cpp/src/models/afmoe.cpp +187 -0
  551. package/cpp/llama.cpp/src/models/apertus.cpp +125 -0
  552. package/cpp/llama.cpp/src/models/arcee.cpp +135 -0
  553. package/cpp/llama.cpp/src/models/arctic.cpp +138 -0
  554. package/cpp/llama.cpp/src/models/arwkv7.cpp +86 -0
  555. package/cpp/llama.cpp/src/models/baichuan.cpp +122 -0
  556. package/cpp/llama.cpp/src/models/bailingmoe.cpp +144 -0
  557. package/cpp/llama.cpp/src/models/bailingmoe2.cpp +135 -0
  558. package/cpp/llama.cpp/src/models/bert.cpp +176 -0
  559. package/cpp/llama.cpp/src/models/bitnet.cpp +160 -0
  560. package/cpp/llama.cpp/src/models/bloom.cpp +101 -0
  561. package/cpp/llama.cpp/src/models/chameleon.cpp +178 -0
  562. package/cpp/llama.cpp/src/models/chatglm.cpp +132 -0
  563. package/cpp/llama.cpp/src/models/codeshell.cpp +111 -0
  564. package/cpp/llama.cpp/src/models/cogvlm.cpp +100 -0
  565. package/cpp/llama.cpp/src/models/cohere2-iswa.cpp +131 -0
  566. package/cpp/llama.cpp/src/models/command-r.cpp +122 -0
  567. package/cpp/llama.cpp/src/models/dbrx.cpp +123 -0
  568. package/cpp/llama.cpp/src/models/deci.cpp +135 -0
  569. package/cpp/llama.cpp/src/models/deepseek.cpp +144 -0
  570. package/cpp/llama.cpp/src/models/deepseek2.cpp +237 -0
  571. package/cpp/llama.cpp/src/models/dots1.cpp +134 -0
  572. package/cpp/llama.cpp/src/models/dream.cpp +105 -0
  573. package/cpp/llama.cpp/src/models/ernie4-5-moe.cpp +150 -0
  574. package/cpp/llama.cpp/src/models/ernie4-5.cpp +110 -0
  575. package/cpp/llama.cpp/src/models/exaone.cpp +114 -0
  576. package/cpp/llama.cpp/src/models/exaone4.cpp +123 -0
  577. package/cpp/llama.cpp/src/models/falcon-h1.cpp +113 -0
  578. package/cpp/llama.cpp/src/models/falcon.cpp +120 -0
  579. package/cpp/llama.cpp/src/models/gemma-embedding.cpp +120 -0
  580. package/cpp/llama.cpp/src/models/gemma.cpp +112 -0
  581. package/cpp/llama.cpp/src/models/gemma2-iswa.cpp +125 -0
  582. package/cpp/llama.cpp/src/models/gemma3-iswa.cpp +131 -0
  583. package/cpp/llama.cpp/src/models/gemma3n-iswa.cpp +377 -0
  584. package/cpp/llama.cpp/src/models/glm4-moe.cpp +153 -0
  585. package/cpp/llama.cpp/src/models/glm4.cpp +127 -0
  586. package/cpp/llama.cpp/src/models/gpt2.cpp +105 -0
  587. package/cpp/llama.cpp/src/models/gptneox.cpp +144 -0
  588. package/cpp/llama.cpp/src/models/granite-hybrid.cpp +196 -0
  589. package/cpp/llama.cpp/src/models/granite.cpp +211 -0
  590. package/cpp/llama.cpp/src/models/graph-context-mamba.cpp +283 -0
  591. package/cpp/llama.cpp/src/models/grok.cpp +159 -0
  592. package/cpp/llama.cpp/src/models/grovemoe.cpp +141 -0
  593. package/cpp/llama.cpp/src/models/hunyuan-dense.cpp +132 -0
  594. package/cpp/llama.cpp/src/models/hunyuan-moe.cpp +154 -0
  595. package/cpp/llama.cpp/src/models/internlm2.cpp +120 -0
  596. package/cpp/llama.cpp/src/models/jais.cpp +86 -0
  597. package/cpp/llama.cpp/src/models/jamba.cpp +106 -0
  598. package/cpp/llama.cpp/src/models/lfm2.cpp +173 -0
  599. package/cpp/llama.cpp/src/models/llada-moe.cpp +122 -0
  600. package/cpp/llama.cpp/src/models/llada.cpp +99 -0
  601. package/cpp/llama.cpp/src/models/llama-iswa.cpp +174 -0
  602. package/cpp/llama.cpp/src/models/llama.cpp +155 -0
  603. package/cpp/llama.cpp/src/models/mamba.cpp +55 -0
  604. package/cpp/llama.cpp/src/models/minicpm3.cpp +199 -0
  605. package/cpp/llama.cpp/src/models/minimax-m2.cpp +124 -0
  606. package/cpp/llama.cpp/src/models/models.h +485 -0
  607. package/cpp/llama.cpp/src/models/mpt.cpp +126 -0
  608. package/cpp/llama.cpp/src/models/nemotron-h.cpp +121 -0
  609. package/cpp/llama.cpp/src/models/nemotron.cpp +122 -0
  610. package/cpp/llama.cpp/src/models/neo-bert.cpp +104 -0
  611. package/cpp/llama.cpp/src/models/olmo.cpp +121 -0
  612. package/cpp/llama.cpp/src/models/olmo2.cpp +150 -0
  613. package/cpp/llama.cpp/src/models/olmoe.cpp +124 -0
  614. package/cpp/llama.cpp/src/models/openai-moe-iswa.cpp +124 -0
  615. package/cpp/llama.cpp/src/models/openelm.cpp +124 -0
  616. package/cpp/llama.cpp/src/models/orion.cpp +123 -0
  617. package/cpp/llama.cpp/src/models/pangu-embedded.cpp +121 -0
  618. package/cpp/llama.cpp/src/models/phi2.cpp +121 -0
  619. package/cpp/llama.cpp/src/models/phi3.cpp +152 -0
  620. package/cpp/llama.cpp/src/models/plamo.cpp +110 -0
  621. package/cpp/llama.cpp/src/models/plamo2.cpp +316 -0
  622. package/cpp/llama.cpp/src/models/plm.cpp +168 -0
  623. package/cpp/llama.cpp/src/models/qwen.cpp +108 -0
  624. package/cpp/llama.cpp/src/models/qwen2.cpp +117 -0
  625. package/cpp/llama.cpp/src/models/qwen2moe.cpp +151 -0
  626. package/cpp/llama.cpp/src/models/qwen2vl.cpp +117 -0
  627. package/cpp/llama.cpp/src/models/qwen3.cpp +117 -0
  628. package/cpp/llama.cpp/src/models/qwen3moe.cpp +124 -0
  629. package/cpp/llama.cpp/src/models/qwen3vl-moe.cpp +149 -0
  630. package/cpp/llama.cpp/src/models/qwen3vl.cpp +141 -0
  631. package/cpp/llama.cpp/src/models/refact.cpp +94 -0
  632. package/cpp/llama.cpp/src/models/rwkv6-base.cpp +162 -0
  633. package/cpp/llama.cpp/src/models/rwkv6.cpp +94 -0
  634. package/cpp/llama.cpp/src/models/rwkv6qwen2.cpp +86 -0
  635. package/cpp/llama.cpp/src/models/rwkv7-base.cpp +135 -0
  636. package/cpp/llama.cpp/src/models/rwkv7.cpp +90 -0
  637. package/cpp/llama.cpp/src/models/seed-oss.cpp +124 -0
  638. package/cpp/llama.cpp/src/models/smallthinker.cpp +120 -0
  639. package/cpp/llama.cpp/src/models/smollm3.cpp +128 -0
  640. package/cpp/llama.cpp/src/models/stablelm.cpp +146 -0
  641. package/cpp/llama.cpp/src/models/starcoder.cpp +100 -0
  642. package/cpp/llama.cpp/src/models/starcoder2.cpp +121 -0
  643. package/cpp/llama.cpp/src/models/t5-dec.cpp +166 -0
  644. package/cpp/llama.cpp/src/models/t5-enc.cpp +96 -0
  645. package/cpp/llama.cpp/src/models/wavtokenizer-dec.cpp +149 -0
  646. package/cpp/llama.cpp/src/models/xverse.cpp +108 -0
  647. package/cpp/llama.cpp/src/unicode.cpp +77 -0
  648. package/cpp/llama.cpp/src/unicode.h +43 -0
  649. package/cpp/llama.cpp/vendor/cpp-httplib/CMakeLists.txt +94 -0
  650. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.cpp +9339 -0
  651. package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +433 -8222
  652. package/cpp/llama.cpp/vendor/cpp-httplib/patch-boringssl.cmake +6 -0
  653. package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +4179 -1900
  654. package/cpp/llama.cpp/vendor/minja/chat-template.hpp +9 -2
  655. package/cpp/llama.cpp/vendor/minja/minja.hpp +101 -22
  656. package/ios/include/chat.h +16 -3
  657. package/ios/include/common/minja/chat-template.hpp +9 -2
  658. package/ios/include/common/minja/minja.hpp +101 -22
  659. package/ios/include/common.h +57 -19
  660. package/ios/include/json-schema-to-grammar.h +2 -0
  661. package/ios/include/llama.h +44 -21
  662. package/ios/include/log.h +12 -4
  663. package/ios/include/sampling.h +3 -1
  664. package/ios/libs/llama.xcframework/Info.plist +20 -20
  665. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  666. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6399 -5557
  667. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +19 -1
  668. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +1 -1
  669. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-metal.h +1 -6
  670. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +199 -6
  671. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +44 -21
  672. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  673. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  674. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +6362 -5520
  675. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4813 -4241
  676. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +19 -1
  677. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +1 -1
  678. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +1 -6
  679. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +199 -6
  680. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +44 -21
  681. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  682. package/package.json +10 -4
  683. package/cpp/llama.cpp/ggml/src/ggml-cann/Doxyfile +0 -2579
  684. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -371
  685. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  686. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -379
  687. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  688. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -495
  689. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -486
  690. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  691. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  692. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  693. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  694. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  695. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  696. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  697. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  698. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  699. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  700. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  701. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  702. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  703. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  704. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  705. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  706. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  707. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  708. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  709. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  710. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  711. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  712. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  713. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  714. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  715. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  716. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  717. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  718. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  719. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  720. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  721. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  722. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  723. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  724. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  725. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  726. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  727. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  728. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  729. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  730. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  731. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  732. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  733. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  734. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  735. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  736. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  737. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  738. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  739. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  740. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  741. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  742. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  743. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  744. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  745. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  746. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  747. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  748. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  749. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  750. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  751. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  752. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  753. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  754. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  755. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  756. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  757. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  758. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  759. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  760. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  761. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  762. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  763. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  764. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  765. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  766. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  767. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  768. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  769. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  770. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  771. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  772. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  773. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  774. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  775. package/cpp/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  776. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +0 -6886
  777. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -154
  778. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  779. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  780. package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  781. package/cpp/llama.cpp/ggml/src/ggml-zdnn/ggml-zdnn-impl.h +0 -97
  782. package/cpp/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  783. package/cpp/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  784. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  785. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +0 -112
  786. package/cpp/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +0 -46
  787. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  788. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.inp +0 -112
  789. package/cpp/llama.cpp/models/ggml-vocab-command-r.gguf.out +0 -46
  790. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  791. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +0 -112
  792. package/cpp/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +0 -46
  793. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  794. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +0 -112
  795. package/cpp/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +0 -46
  796. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  797. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.inp +0 -112
  798. package/cpp/llama.cpp/models/ggml-vocab-falcon.gguf.out +0 -46
  799. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  800. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +0 -112
  801. package/cpp/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +0 -46
  802. package/cpp/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  803. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  804. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +0 -112
  805. package/cpp/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +0 -46
  806. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  807. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +0 -112
  808. package/cpp/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +0 -46
  809. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  810. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.inp +0 -112
  811. package/cpp/llama.cpp/models/ggml-vocab-mpt.gguf.out +0 -46
  812. package/cpp/llama.cpp/models/ggml-vocab-nomic-bert-moe.gguf +0 -0
  813. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  814. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +0 -112
  815. package/cpp/llama.cpp/models/ggml-vocab-phi-3.gguf.out +0 -46
  816. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  817. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +0 -112
  818. package/cpp/llama.cpp/models/ggml-vocab-qwen2.gguf.out +0 -46
  819. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  820. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.inp +0 -112
  821. package/cpp/llama.cpp/models/ggml-vocab-refact.gguf.out +0 -46
  822. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  823. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +0 -112
  824. package/cpp/llama.cpp/models/ggml-vocab-starcoder.gguf.out +0 -46
  825. package/cpp/llama.cpp/models/templates/ByteDance-Seed-OSS.jinja +0 -171
  826. package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja +0 -202
  827. package/cpp/llama.cpp/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja +0 -156
  828. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +0 -124
  829. package/cpp/llama.cpp/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja +0 -152
  830. package/cpp/llama.cpp/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja +0 -152
  831. package/cpp/llama.cpp/models/templates/Qwen-QwQ-32B.jinja +0 -62
  832. package/cpp/llama.cpp/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +0 -54
  833. package/cpp/llama.cpp/models/templates/Qwen-Qwen3-0.6B.jinja +0 -85
  834. package/cpp/llama.cpp/models/templates/README.md +0 -25
  835. package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +0 -1
  836. package/cpp/llama.cpp/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +0 -1
  837. package/cpp/llama.cpp/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja +0 -57
  838. package/cpp/llama.cpp/models/templates/google-gemma-2-2b-it.jinja +0 -4
  839. package/cpp/llama.cpp/models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja +0 -59
  840. package/cpp/llama.cpp/models/templates/llama-cpp-deepseek-r1.jinja +0 -76
  841. package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +0 -34
  842. package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.1.jinja +0 -58
  843. package/cpp/llama.cpp/models/templates/meetkai-functionary-medium-v3.2.jinja +0 -287
  844. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja +0 -109
  845. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja +0 -93
  846. package/cpp/llama.cpp/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja +0 -109
  847. package/cpp/llama.cpp/models/templates/microsoft-Phi-3.5-mini-instruct.jinja +0 -8
  848. package/cpp/llama.cpp/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja +0 -87
  849. package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +0 -43
  850. package/cpp/llama.cpp/models/templates/openai-gpt-oss-120b.jinja +0 -331
  851. package/cpp/llama.cpp/models/templates/unsloth-mistral-Devstral-Small-2507.jinja +0 -105
  852. package/cpp/llama.cpp/prompts/LLM-questions.txt +0 -49
  853. package/cpp/llama.cpp/prompts/alpaca.txt +0 -1
  854. package/cpp/llama.cpp/prompts/assistant.txt +0 -31
  855. package/cpp/llama.cpp/prompts/chat-with-baichuan.txt +0 -4
  856. package/cpp/llama.cpp/prompts/chat-with-bob.txt +0 -7
  857. package/cpp/llama.cpp/prompts/chat-with-qwen.txt +0 -1
  858. package/cpp/llama.cpp/prompts/chat-with-vicuna-v0.txt +0 -7
  859. package/cpp/llama.cpp/prompts/chat-with-vicuna-v1.txt +0 -7
  860. package/cpp/llama.cpp/prompts/chat.txt +0 -28
  861. package/cpp/llama.cpp/prompts/dan-modified.txt +0 -1
  862. package/cpp/llama.cpp/prompts/dan.txt +0 -1
  863. package/cpp/llama.cpp/prompts/mnemonics.txt +0 -93
  864. package/cpp/llama.cpp/prompts/parallel-questions.txt +0 -43
  865. package/cpp/llama.cpp/prompts/reason-act.txt +0 -18
  866. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  867. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  868. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
  869. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4247
  870. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-alloc.h +0 -76
  871. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +0 -354
  872. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-blas.h +0 -25
  873. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +0 -145
  874. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-metal.h +0 -66
  875. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-opt.h +0 -256
  876. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +0 -2492
  877. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/gguf.h +0 -202
  878. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +0 -1391
  879. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Modules/module.modulemap +0 -17
  880. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Resources/Info.plist +0 -32
  881. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-alloc.h +0 -76
  882. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +0 -354
  883. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-blas.h +0 -25
  884. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +0 -145
  885. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-metal.h +0 -66
  886. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-opt.h +0 -256
  887. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +0 -2492
  888. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/gguf.h +0 -202
  889. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +0 -1391
  890. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Modules/module.modulemap +0 -17
  891. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Resources/Info.plist +0 -32
  892. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  893. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-alloc.h +0 -76
  894. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +0 -354
  895. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-blas.h +0 -25
  896. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +0 -145
  897. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-metal.h +0 -66
  898. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-opt.h +0 -256
  899. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +0 -2492
  900. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/gguf.h +0 -202
  901. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +0 -1391
  902. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Modules/module.modulemap +0 -17
  903. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Resources/Info.plist +0 -32
  904. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  905. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  906. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  907. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  908. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5561
  909. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
  910. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +0 -354
  911. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-blas.h +0 -25
  912. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
  913. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-metal.h +0 -66
  914. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-opt.h +0 -256
  915. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +0 -2492
  916. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/gguf.h +0 -202
  917. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +0 -1391
  918. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Info.plist +0 -35
  919. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Modules/module.modulemap +0 -17
  920. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  921. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  922. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  923. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5524
  924. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4246
  925. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
  926. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
  927. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
  928. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
  929. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
  930. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
  931. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
  932. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
  933. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
  934. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Info.plist +0 -35
  935. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
  936. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  937. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  938. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  939. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5558
  940. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-alloc.h +0 -76
  941. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +0 -354
  942. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-blas.h +0 -25
  943. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +0 -145
  944. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-metal.h +0 -66
  945. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-opt.h +0 -256
  946. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +0 -2492
  947. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/gguf.h +0 -202
  948. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +0 -1391
  949. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Info.plist +0 -32
  950. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Modules/module.modulemap +0 -17
  951. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  952. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Info.plist +0 -20
  953. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  954. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +0 -5520
  955. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +0 -4243
  956. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-alloc.h +0 -76
  957. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +0 -354
  958. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-blas.h +0 -25
  959. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +0 -145
  960. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-metal.h +0 -66
  961. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-opt.h +0 -256
  962. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +0 -2492
  963. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/gguf.h +0 -202
  964. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +0 -1391
  965. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Info.plist +0 -32
  966. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Modules/module.modulemap +0 -17
  967. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  968. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  969. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  970. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  971. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  972. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +0 -0
  973. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +0 -0
  974. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  975. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  976. /package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -0,0 +1,2223 @@
1
+ #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
2
+ #pragma clang diagnostic ignored "-Wunused-function"
3
+ #pragma clang diagnostic ignored "-Wunused-variable"
4
+ #pragma clang diagnostic ignored "-Wunused-but-set-variable"
5
+
6
+ #ifdef HTP_DEBUG
7
+ # define FARF_HIGH 1
8
+ #endif
9
+
10
+ #include <HAP_farf.h>
11
+ #include <HAP_mem.h>
12
+ #include <HAP_perf.h>
13
+ #include <HAP_ps.h>
14
+ #include <hexagon_protos.h>
15
+ #include <hexagon_types.h>
16
+ #include <math.h>
17
+ #include <qurt_thread.h>
18
+ #include <string.h>
19
+
20
+ #define GGML_COMMON_DECL_C
21
+ #include "ggml-common.h"
22
+ #include "htp-ctx.h"
23
+ #include "htp-dma.h"
24
+ #include "htp-msg.h"
25
+ #include "htp-ops.h"
26
+ #include "hvx-utils.h"
27
+ #include "ops-utils.h"
28
+
29
+ struct htp_matmul_type {
30
+ const char * type;
31
+ void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
32
+ void (*vec_dot_rx2)(const int n,
33
+ float * restrict s,
34
+ const void * restrict vx,
35
+ uint32_t vx_row_size,
36
+ const void * restrict vy);
37
+ };
38
+
39
+ typedef struct {
40
+ HVX_Vector v[2];
41
+ } HVX_Vector_x2;
42
+
43
+ typedef struct {
44
+ HVX_Vector v[4];
45
+ } HVX_Vector_x4;
46
+
47
+ typedef struct {
48
+ HVX_Vector v[8];
49
+ } HVX_Vector_x8;
50
+
51
+ // vdelta control to replicate first 4x fp32 values across lanes
52
+ static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
53
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
54
+ 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
55
+ 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
56
+ 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
57
+ 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
58
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
59
+ 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
60
+ };
61
+
62
+ // vdelta control to replicate and interleave first 8x fp32 values across lanes
63
+ static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = {
64
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
65
+ 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
66
+ 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
67
+ 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
68
+ 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
69
+ 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
70
+ 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
71
+ };
72
+
73
+ // vdelta control to replicate first fp32 value across all elements
74
+ static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = {
75
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
76
+ 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
77
+ 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
78
+ 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
79
+ 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
80
+ 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
81
+ 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
82
+ };
83
+
84
+ // vdelta control to replicate first fp16 value across all elements
85
+ static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
86
+ 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
87
+ 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
88
+ 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
89
+ 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
90
+ 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
91
+ 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
92
+ 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
93
+ };
94
+
95
+ // vdelta control to expand first 32 e8m0 values into 32 uint32 elements
96
+ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
97
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
98
+ 0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04,
99
+ 0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02,
100
+ 0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08,
101
+ 0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48,
102
+ 0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00,
103
+ 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
104
+ };
105
+
106
+ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
107
+ 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
108
+ 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
109
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
110
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
111
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
112
+ };
113
+
114
+ // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
115
+
116
+ static inline size_t q8x4x2_row_size(uint32_t ne) {
117
+ // ensures perfect alignment of quants and full row
118
+ const uint32_t qk = QK_Q8_0x4x2;
119
+ const uint32_t nb = (ne + qk - 1) / qk;
120
+ return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128);
121
+ }
122
+
123
+ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
124
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
125
+
126
+ HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
127
+ HVX_Vector v2_3 = vptr[1]; // ...
128
+ HVX_Vector v4_5 = vptr[2]; // ...
129
+ HVX_Vector v6_7 = vptr[3]; // ...
130
+
131
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
132
+
133
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
134
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
135
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
136
+ HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
137
+ HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
138
+ HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
139
+ HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
140
+ HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
141
+
142
+ // Convert uint4 to int4 (i.e. x - 8)
143
+ const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
144
+ v0 = Q6_Vb_vsub_VbVb(v0, i8);
145
+ v1 = Q6_Vb_vsub_VbVb(v1, i8);
146
+ v2 = Q6_Vb_vsub_VbVb(v2, i8);
147
+ v3 = Q6_Vb_vsub_VbVb(v3, i8);
148
+ v4 = Q6_Vb_vsub_VbVb(v4, i8);
149
+ v5 = Q6_Vb_vsub_VbVb(v5, i8);
150
+ v6 = Q6_Vb_vsub_VbVb(v6, i8);
151
+ v7 = Q6_Vb_vsub_VbVb(v7, i8);
152
+
153
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
154
+ return r;
155
+ }
156
+
157
+ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
158
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
159
+
160
+ HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
161
+ HVX_Vector v2_3 = vptr[1]; // ...
162
+ HVX_Vector v4_5 = vptr[2]; // ...
163
+ HVX_Vector v6_7 = vptr[3]; // ...
164
+
165
+ const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
166
+
167
+ HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
168
+ HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
169
+ HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
170
+ HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
171
+ HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
172
+ HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
173
+ HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
174
+ HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
175
+
176
+ HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
177
+ v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
178
+ v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
179
+ v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
180
+ v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
181
+ v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
182
+ v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
183
+ v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
184
+ v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
185
+
186
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
187
+ return r;
188
+ }
189
+
190
+ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
191
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
192
+
193
+ HVX_Vector v0 = vptr[0]; // first 128 vals
194
+ HVX_Vector v1 = vptr[1]; // ...
195
+ HVX_Vector v2 = vptr[2]; // ...
196
+ HVX_Vector v3 = vptr[3]; // ...
197
+ HVX_Vector v4 = vptr[4]; // ...
198
+ HVX_Vector v5 = vptr[5]; // ...
199
+ HVX_Vector v6 = vptr[6]; // ...
200
+ HVX_Vector v7 = vptr[7]; // ...
201
+
202
+ HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
203
+ return r;
204
+ }
205
+
206
+ static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) {
207
+ const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
208
+
209
+ HVX_Vector v0 = vptr[0]; // first 64 vals
210
+ HVX_Vector v1 = vptr[1]; // second 64 vals
211
+ HVX_Vector v2 = vptr[2]; // third 64 vals
212
+ HVX_Vector v3 = vptr[3]; // forth 64 vals
213
+
214
+ HVX_Vector_x4 r = { v0, v1, v2, v3 };
215
+ return r;
216
+ }
217
+
218
+ static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
219
+ const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;
220
+
221
+ HVX_VectorPair v0 = vptr[0]; // first 64 vals
222
+ HVX_VectorPair v1 = vptr[1]; // second 64 vals
223
+ HVX_VectorPair v2 = vptr[2]; // third 64 vals
224
+ HVX_VectorPair v3 = vptr[3]; // forth 64 vals
225
+
226
+ HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
227
+ HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
228
+ HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
229
+ HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
230
+ HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
231
+ HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
232
+ HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
233
+ HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());
234
+
235
+ HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
236
+ HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
237
+ HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
238
+ HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));
239
+
240
+ // vcombine does a shuffle, use vdeal to undo
241
+
242
+ HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
243
+ return r;
244
+ }
245
+
246
+ // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
247
+ // Accumulate each block into a single int32 value.
248
+ // Return a single HVX vector with 32x int32 accumulators.
249
+ // This version is parameterized to support less than 1024 elements.
250
+ // if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
251
+
252
+ static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
253
+ HVX_Vector r0 = Q6_V_vsplat_R(0);
254
+ HVX_Vector r1 = Q6_V_vsplat_R(0);
255
+ HVX_Vector r2 = Q6_V_vsplat_R(0);
256
+ HVX_Vector r3 = Q6_V_vsplat_R(0);
257
+ HVX_Vector r4 = Q6_V_vsplat_R(0);
258
+ HVX_Vector r5 = Q6_V_vsplat_R(0);
259
+ HVX_Vector r6 = Q6_V_vsplat_R(0);
260
+ HVX_Vector r7 = Q6_V_vsplat_R(0);
261
+
262
+ HVX_VectorPair p3;
263
+ HVX_VectorPair p2;
264
+ HVX_VectorPair p1;
265
+ HVX_VectorPair p0;
266
+
267
+ if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); }
268
+ if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); }
269
+ if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); }
270
+ if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); }
271
+ if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); }
272
+ if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); }
273
+ if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); }
274
+ if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); }
275
+
276
+ if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
277
+ if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
278
+ if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); }
279
+ if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); }
280
+
281
+ if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
282
+ if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
283
+ if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); }
284
+ if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); }
285
+
286
+ if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
287
+ if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); }
288
+
289
+ if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
290
+ if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); }
291
+
292
+ if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); }
293
+ if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); }
294
+
295
+ return r0;
296
+ }
297
+
298
+ static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
299
+ return hvx_vec_rmpy_x8_n(x, y, 1024);
300
+ }
301
+
302
+ // Handle most common cases of tensors not multiple of 1024.
303
+ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
304
+ if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
305
+ if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
306
+ if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
307
+ return hvx_vec_rmpy_x8_n(x, y, 1024);
308
+ }
309
+
310
+ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
311
+ assert(n % 32 == 0); // min sub-block size
312
+ assert((unsigned long) vx % 128 == 0);
313
+ assert((unsigned long) vy % 128 == 0);
314
+
315
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
316
+
317
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
318
+ const uint32_t x_qblk_size = qk / 2; // int4
319
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
320
+
321
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
322
+ const uint32_t y_qblk_size = qk; // int8
323
+ const uint32_t y_qrow_size = n; // int8 (not padded)
324
+
325
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
326
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
327
+
328
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
329
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
330
+
331
+ // Row sum (qf32)
332
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
333
+
334
+ // Multiply and accumulate into int32.
335
+ // Compute combined scale (fp32).
336
+ // Apply scale to acc and accumulate into the row sum (qf32).
337
+
338
+ const uint32_t nb = n / qk; // num full blocks
339
+ const uint32_t nloe = n % qk; // num leftover elemements
340
+
341
+ uint32_t i = 0;
342
+ for (; i < nb; i++) {
343
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
344
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
345
+
346
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
347
+
348
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
349
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
350
+
351
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
352
+
353
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
354
+
355
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
356
+ }
357
+
358
+ // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
359
+ if (nloe) {
360
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
361
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
362
+
363
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
364
+
365
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
366
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
367
+
368
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
369
+
370
+ // Zero out unused scales
371
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
372
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
373
+
374
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
375
+
376
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
377
+ }
378
+
379
+ // Reduce and convert into fp32
380
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
381
+
382
+ hvx_vec_store_u(&s[0], 4, r0_sum);
383
+ }
384
+
385
+ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
386
+ float * restrict s,
387
+ const void * restrict vx,
388
+ uint32_t vx_row_size,
389
+ const void * restrict vy) {
390
+ assert(n % 32 == 0); // min sub-block size
391
+ assert((unsigned long) vx % 128 == 0);
392
+ assert((unsigned long) vy % 128 == 0);
393
+
394
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
395
+
396
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
397
+ const uint32_t x_qblk_size = qk / 2; // int4
398
+ const uint32_t x_qrow_size = n / 2; // int4 (not padded)
399
+
400
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
401
+ const uint32_t y_qblk_size = qk; // int8
402
+ const uint32_t y_qrow_size = n; // int8 (not padded)
403
+
404
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
405
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
406
+
407
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
408
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
409
+
410
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
411
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
412
+
413
+ // Row sum (qf32)
414
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
415
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
416
+
417
+ // Multiply and accumulate into int32.
418
+ // Compute combined scale (fp32).
419
+ // Apply scale to acc and accumulate into the row sum (qf32).
420
+
421
+ const uint32_t nb = n / qk; // num full blocks
422
+ const uint32_t nloe = n % qk; // num leftover elemements
423
+
424
+ uint32_t i = 0;
425
+ for (; i < nb; i++) {
426
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
427
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
428
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
429
+
430
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
431
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
432
+
433
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
434
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
435
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
436
+
437
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
438
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
439
+
440
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
441
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
442
+
443
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
444
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
445
+ }
446
+
447
+ // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
448
+ if (nloe) {
449
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
450
+ HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
451
+ HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
452
+
453
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
454
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
455
+
456
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
457
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
458
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
459
+
460
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
461
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
462
+
463
+ // Zero out unused scales
464
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
465
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
466
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
467
+
468
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
469
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
470
+
471
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
472
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
473
+ }
474
+
475
+ // Convert into fp32 and reduce
476
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
477
+ r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
478
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
479
+
480
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
481
+ }
482
+
483
+ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
484
+ assert(n % 32 == 0); // min sub-block size
485
+ assert((unsigned long) vx % 128 == 0);
486
+ assert((unsigned long) vy % 128 == 0);
487
+
488
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
489
+
490
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
491
+ const uint32_t x_qblk_size = qk; // int8
492
+ const uint32_t x_qrow_size = n; // int8 (not padded)
493
+
494
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
495
+ const uint32_t y_qblk_size = qk; // int8
496
+ const uint32_t y_qrow_size = n; // int8 (not padded)
497
+
498
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
499
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
500
+
501
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
502
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
503
+
504
+ // Row sum (qf32)
505
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
506
+
507
+ // Multiply and accumulate into int32.
508
+ // Compute combined scale (fp32).
509
+ // Apply scale to acc and accumulate into the row sum (qf32).
510
+
511
+ const uint32_t nb = n / qk; // num full blocks
512
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
513
+
514
+ uint32_t i = 0;
515
+ for (; i < nb; i++) {
516
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
517
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
518
+
519
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
520
+
521
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
522
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
523
+
524
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
525
+
526
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
527
+
528
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
529
+ }
530
+
531
+ // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
532
+ if (nloe) {
533
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
534
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
535
+
536
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
537
+
538
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
539
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
540
+
541
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
542
+
543
+ // Zero out unused scales
544
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
545
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
546
+
547
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
548
+
549
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
550
+ }
551
+
552
+ // Reduce and convert into fp32
553
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
554
+
555
+ hvx_vec_store_u(&s[0], 4, r0_sum);
556
+ }
557
+
558
+ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
559
+ float * restrict s,
560
+ const void * restrict vx,
561
+ uint32_t vx_row_size,
562
+ const void * restrict vy) {
563
+ assert(n % 32 == 0); // min sub-block size
564
+ assert((unsigned long) vx % 128 == 0);
565
+ assert((unsigned long) vy % 128 == 0);
566
+
567
+ const uint32_t qk = QK_Q4_0x4x2 * 4;
568
+
569
+ const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
570
+ const uint32_t x_qblk_size = qk; // int8
571
+ const uint32_t x_qrow_size = n; // int8 (not padded)
572
+
573
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
574
+ const uint32_t y_qblk_size = qk; // int8
575
+ const uint32_t y_qrow_size = n; // int8 (not padded)
576
+
577
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
578
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
579
+
580
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
581
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
582
+
583
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
584
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
585
+
586
+ // Row sum (qf32)
587
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
588
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
589
+
590
+ // Multiply and accumulate into int32.
591
+ // Compute combined scale (fp32).
592
+ // Apply scale to acc and accumulate into the row sum (qf32).
593
+
594
+ const uint32_t nb = n / qk; // num full blocks
595
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
596
+
597
+ uint32_t i = 0;
598
+ for (; i < nb; i++) {
599
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
600
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
601
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
602
+
603
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
604
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
605
+
606
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
607
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
608
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
609
+
610
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
611
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
612
+
613
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
614
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
615
+
616
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
617
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
618
+ }
619
+
620
+ // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
621
+ if (nloe) {
622
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
623
+ HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
624
+ HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
625
+
626
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
627
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
628
+
629
+ HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
630
+ HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
631
+ HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
632
+
633
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
634
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
635
+
636
+ // Zero out unused scales
637
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
638
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
639
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
640
+
641
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
642
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
643
+
644
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
645
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
646
+ }
647
+
648
+ // Convert into fp32 and reduce
649
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
650
+ r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
651
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
652
+
653
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
654
+ }
655
+
656
+ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
657
+ float * restrict s,
658
+ const void * restrict vx,
659
+ const void * restrict vy) {
660
+ assert(n % 32 == 0); // min sub-block size
661
+ assert((unsigned long) vx % 128 == 0);
662
+ assert((unsigned long) vy % 128 == 0);
663
+
664
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
665
+
666
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
667
+ const uint32_t x_qblk_size = qk / 2; // fp4
668
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
669
+
670
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
671
+ const uint32_t y_qblk_size = qk; // int8
672
+ const uint32_t y_qrow_size = n; // int8 (not padded)
673
+
674
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first
675
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales
676
+
677
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
678
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
679
+
680
+ // Row sum (qf32)
681
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
682
+
683
+ // Multiply and accumulate into int32.
684
+ // Compute combined scale (fp32).
685
+ // Apply scale to acc and accumulate into the row sum (qf32).
686
+
687
+ const uint32_t nb = n / qk; // num full blocks
688
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
689
+
690
+ uint32_t i = 0;
691
+ for (; i < nb; i++) {
692
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
693
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
694
+
695
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
696
+
697
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
698
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
699
+
700
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
701
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
702
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
703
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
704
+
705
+ // Convert rX_d scales from e8m0 to fp32
706
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
707
+ // Left shift with zero fill to create FP32
708
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
709
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
710
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
711
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
712
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
713
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
714
+
715
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
716
+
717
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
718
+
719
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
720
+ }
721
+
722
+ // Process leftovers
723
+ if (nloe) {
724
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
725
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
726
+
727
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
728
+
729
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
730
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
731
+
732
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
733
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
734
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
735
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
736
+
737
+ // Convert rX_d scales from e8m0 to fp32
738
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
739
+ // Left shift with zero fill to create FP32
740
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
741
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
742
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
743
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
744
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
745
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
746
+
747
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
748
+
749
+ // Zero-out unused scales
750
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
751
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
752
+
753
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
754
+
755
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
756
+ }
757
+
758
+ // Reduce and convert into fp32
759
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
760
+
761
+ hvx_vec_store_u(&s[0], 4, r0_sum);
762
+ }
763
+
764
+ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
765
+ float * restrict s,
766
+ const void * restrict vx,
767
+ uint32_t vx_row_size,
768
+ const void * restrict vy) {
769
+ assert(n % 32 == 0); // min sub-block size
770
+ assert((unsigned long) vx % 128 == 0);
771
+ assert((unsigned long) vy % 128 == 0);
772
+
773
+ const uint32_t qk = QK_MXFP4x4x2 * 4;
774
+
775
+ const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
776
+ const uint32_t x_qblk_size = qk / 2; // fp4
777
+ const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
778
+
779
+ const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
780
+ const uint32_t y_qblk_size = qk; // int8
781
+ const uint32_t y_qrow_size = n; // int8 (not padded)
782
+
783
+ const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
784
+ const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
785
+
786
+ const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
787
+ const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
788
+
789
+ const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
790
+ const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
791
+
792
+ // Row sum (qf32)
793
+ HVX_Vector r0_sum = Q6_V_vsplat_R(0);
794
+ HVX_Vector r1_sum = Q6_V_vsplat_R(0);
795
+
796
+ // Multiply and accumulate into int32.
797
+ // Compute combined scale (fp32).
798
+ // Apply scale to acc and accumulate into the row sum (qf32).
799
+
800
+ const uint32_t nb = n / qk; // num full blocks
801
+ int32_t nloe = n % qk; // num leftover elemements (must be signed)
802
+
803
+ uint32_t i = 0;
804
+ for (; i < nb; i++) {
805
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
806
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
807
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
808
+
809
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
810
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
811
+
812
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
813
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
814
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
815
+
816
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
817
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
818
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
819
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
820
+
821
+ // Convert rX_d scales from e8m0 to fp32
822
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
823
+ // Left shift with zero fill to create FP32
824
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
825
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
826
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
827
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
828
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
829
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
830
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
831
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
832
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
833
+
834
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
835
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
836
+
837
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
838
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
839
+
840
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
841
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
842
+ }
843
+
844
+ // Process leftovers
845
+ if (nloe) {
846
+ HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
847
+ HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
848
+ HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
849
+
850
+ HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
851
+ HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
852
+
853
+ HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
854
+ HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
855
+ HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
856
+
857
+ // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
858
+ HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
859
+ vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half));
860
+ vy_d = Q6_Vsf_equals_Vqf32(vy_d);
861
+
862
+ // Convert rX_d scales from e8m0 to fp32
863
+ // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
864
+ // Left shift with zero fill to create FP32
865
+ // FIXME: might need to handle zero as a special case (see ggml-cpu code)
866
+ HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
867
+ HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
868
+ r0_d = Q6_V_vdelta_VV(r0_d, expand);
869
+ r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
870
+ r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
871
+ r1_d = Q6_V_vdelta_VV(r1_d, expand);
872
+ r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
873
+ r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
874
+
875
+ HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
876
+ HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
877
+
878
+ // Zero-out unused scales
879
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
880
+ r0_dd = Q6_V_vand_QV(bmask, r0_dd);
881
+ r1_dd = Q6_V_vand_QV(bmask, r1_dd);
882
+
883
+ HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
884
+ HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
885
+
886
+ r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
887
+ r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
888
+ }
889
+
890
+ // Convert into fp32 and reduce
891
+ r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
892
+ r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
893
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
894
+
895
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
896
+ }
897
+
898
+ #if 1
899
+ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
900
+ if (0) {
901
+ float rsum = 0;
902
+ const __fp16 * restrict vx = (const __fp16 * restrict) x;
903
+ const float * restrict vy = (const float * restrict) y;
904
+
905
+ for (uint32_t i = 0; i < n; i++) {
906
+ rsum += vx[i] * (__fp16) vy[i];
907
+ }
908
+ *s = rsum;
909
+ return;
910
+ }
911
+
912
+ const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
913
+ const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y;
914
+
915
+ uint32_t nv0 = n / 64; // num full fp16 hvx vectors
916
+ uint32_t nv1 = n % 64; // leftover elements
917
+
918
+ // for some reason we need volatile here so that the compiler doesn't try anything funky
919
+ volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
920
+
921
+ uint32_t i = 0;
922
+
923
+ for (i = 0; i < nv0; i++) {
924
+ HVX_VectorPair yp = vy[i];
925
+
926
+ HVX_Vector x = vx[i];
927
+ HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
928
+
929
+ HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
930
+ HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
931
+
932
+ HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
933
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
934
+ }
935
+
936
+ if (nv1) {
937
+ HVX_VectorPair yp = vy[i];
938
+
939
+ HVX_Vector x = vx[i];
940
+ HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
941
+
942
+ if (nv1 >= 32) {
943
+ HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
944
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
945
+ nv1 -= 32;
946
+ }
947
+
948
+ rsum = hvx_vec_qf32_reduce_sum(rsum);
949
+
950
+ if (nv1) {
951
+ HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
952
+ HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
953
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
954
+ }
955
+
956
+ // hvx_vec_dump_fp16("X", x);
957
+ // hvx_vec_dump_fp16("Y", y);
958
+ // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum));
959
+ // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum));
960
+ } else {
961
+ rsum = hvx_vec_qf32_reduce_sum(rsum);
962
+ }
963
+
964
+ *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum));
965
+
966
+ # ifdef HTP_DEBUG
967
+ {
968
+ float rsum = 0;
969
+ const __fp16 * restrict vx = (const __fp16 * restrict) x;
970
+ const float * restrict vy = (const float * restrict) y;
971
+
972
+ for (uint32_t i = 0; i < n; i++) {
973
+ rsum += vx[i] * vy[i];
974
+ }
975
+
976
+ float diff = fabs(*s - rsum);
977
+ if (diff > 0.001) {
978
+ FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s);
979
+ // htp_dump_f16("x", vx, n);
980
+ // htp_dump_f32("y", vy, n);
981
+ }
982
+ }
983
+ # endif
984
+ }
985
+ #else
986
+ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
987
+ const uint32_t fk = 64;
988
+ const uint32_t nb = n / fk;
989
+
990
+ assert(n % fk == 0);
991
+ assert(nb % 4 == 0);
992
+
993
+ const uint32_t x_blk_size = 2 * fk; // fp16
994
+ const uint32_t y_blk_size = 4 * fk; // fp32
995
+
996
+ // Row sum (qf32)
997
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
998
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
999
+ HVX_Vector rsum2 = Q6_V_vsplat_R(0);
1000
+ HVX_Vector rsum3 = Q6_V_vsplat_R(0);
1001
+
1002
+ for (uint32_t i = 0; i < nb; i += 4) {
1003
+ HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size));
1004
+ HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size));
1005
+
1006
+ HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]);
1007
+ HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]);
1008
+ HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]);
1009
+ HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]);
1010
+
1011
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0)));
1012
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1)));
1013
+ rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2)));
1014
+ rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3)));
1015
+ }
1016
+
1017
+ // Reduce and convert into fp32
1018
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1);
1019
+ rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3);
1020
+ HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2));
1021
+ hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum));
1022
+ }
1023
+ #endif
1024
+
1025
// Expands the ggml-style geometry of the three tensors expected in scope
// (src0, src1, dst) into local constants: ne* are element counts per
// dimension, nb* are byte strides per dimension.
#define htp_matmul_preamble \
    const uint32_t ne00 = src0->ne[0]; \
    const uint32_t ne01 = src0->ne[1]; \
    const uint32_t ne02 = src0->ne[2]; \
    const uint32_t ne03 = src0->ne[3]; \
    \
    const uint32_t ne10 = src1->ne[0]; \
    const uint32_t ne11 = src1->ne[1]; \
    const uint32_t ne12 = src1->ne[2]; \
    const uint32_t ne13 = src1->ne[3]; \
    \
    const uint32_t ne0 = dst->ne[0]; \
    const uint32_t ne1 = dst->ne[1]; \
    const uint32_t ne2 = dst->ne[2]; \
    const uint32_t ne3 = dst->ne[3]; \
    \
    const uint32_t nb00 = src0->nb[0]; \
    const uint32_t nb01 = src0->nb[1]; \
    const uint32_t nb02 = src0->nb[2]; \
    const uint32_t nb03 = src0->nb[3]; \
    \
    const uint32_t nb10 = src1->nb[0]; \
    const uint32_t nb11 = src1->nb[1]; \
    const uint32_t nb12 = src1->nb[2]; \
    const uint32_t nb13 = src1->nb[3]; \
    \
    const uint32_t nb0 = dst->nb[0]; \
    const uint32_t nb1 = dst->nb[1]; \
    const uint32_t nb2 = dst->nb[2]; \
    const uint32_t nb3 = dst->nb[3];
1055
+
1056
+ // q8x4 src1 tensor is already in VTCM spad
1057
+ static void matmul(struct htp_matmul_type * mt,
1058
+ struct htp_tensor * restrict src0,
1059
+ struct htp_tensor * restrict src1,
1060
+ struct htp_tensor * restrict dst,
1061
+ struct htp_spad * restrict src0_spad,
1062
+ struct htp_spad * restrict src1_spad,
1063
+ struct htp_spad * restrict dst_spad,
1064
+ uint32_t nth,
1065
+ uint32_t ith,
1066
+ uint32_t src0_nrows_per_thread,
1067
+ dma_queue * dma_queue) {
1068
+ htp_matmul_preamble;
1069
+
1070
+ const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
1071
+ const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
1072
+
1073
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1074
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1075
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1076
+
1077
+ // no work for this thread
1078
+ if (src0_start_row >= src0_end_row) {
1079
+ return;
1080
+ }
1081
+
1082
+ const size_t dst_row_size = nb1;
1083
+ const size_t src0_row_size = nb01;
1084
+ const size_t src1_row_size = q8x4x2_row_size(ne10);
1085
+
1086
+ const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
1087
+
1088
+ // Per-thread VTCM scratchpads for all tensors
1089
+ // Note that the entire src1 tensor is already in VTCM
1090
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
1091
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1092
+ uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1093
+ uint8_t * restrict src1_data = src1_spad->data;
1094
+
1095
+ volatile uint64_t t1, t2;
1096
+ t1 = HAP_perf_get_qtimer_count();
1097
+
1098
+ const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
1099
+
1100
+ // Prefill spad with src0 rows
1101
+ #pragma unroll(4)
1102
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1103
+ const int is0 = (ir0 - src0_start_row);
1104
+ if (is0 >= HTP_SPAD_SRC0_NROWS) {
1105
+ break;
1106
+ }
1107
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1108
+ src0_row_size_padded, src0_row_size, 2);
1109
+ }
1110
+
1111
+ // Process src0 rows
1112
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1113
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1114
+
1115
+ #pragma unroll(2)
1116
+ for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
1117
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
1118
+ float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
1119
+ mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
1120
+ }
1121
+
1122
+ // Prefetch next (n + spad_nrows) row
1123
+ const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
1124
+ const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
1125
+ if (pr0 < src0_end_row_x2) {
1126
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
1127
+ src0_row_size_padded, src0_row_size, 2);
1128
+ }
1129
+ }
1130
+
1131
+ // Process the last row (if any)
1132
+ if (src0_end_row != src0_end_row_x2) {
1133
+ uint32_t ir0 = src0_end_row_x2;
1134
+ const int is0 = (ir0 - src0_start_row);
1135
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1136
+ src0_row_size_padded, src0_row_size, 1);
1137
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1138
+
1139
+ #pragma unroll(2)
1140
+ for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
1141
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
1142
+ float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
1143
+ mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
1144
+ }
1145
+ }
1146
+
1147
+ t2 = HAP_perf_get_qtimer_count();
1148
+
1149
+ FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
1150
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1151
+ src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1152
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1153
+ }
1154
+
1155
+ // q8x4x2 src1 tensor is already in VTCM spad
1156
+ static void matvec(struct htp_matmul_type * mt,
1157
+ struct htp_tensor * restrict src0,
1158
+ struct htp_tensor * restrict src1,
1159
+ struct htp_tensor * restrict dst,
1160
+ struct htp_spad * restrict src0_spad,
1161
+ struct htp_spad * restrict src1_spad,
1162
+ struct htp_spad * restrict dst_spad,
1163
+ uint32_t nth,
1164
+ uint32_t ith,
1165
+ uint32_t src0_nrows_per_thread,
1166
+ dma_queue * dma_queue) {
1167
+ htp_matmul_preamble;
1168
+
1169
+ const uint32_t src0_nrows = ne01;
1170
+
1171
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1172
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1173
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1174
+
1175
+ // no work for this thread
1176
+ if (src0_start_row >= src0_end_row) {
1177
+ return;
1178
+ }
1179
+
1180
+ const size_t dst_row_size = nb1;
1181
+ const size_t src0_row_size = nb01;
1182
+ const size_t src1_row_size = q8x4x2_row_size(ne10);
1183
+
1184
+ const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
1185
+
1186
+ // Per-thread VTCM scratchpads for all tensors
1187
+ // Note that the entire src1 tensor is already in VTCM
1188
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
1189
+ uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1190
+ uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1191
+ uint8_t * src1_data = src1_spad->data;
1192
+
1193
+ uint64_t t1, t2;
1194
+ t1 = HAP_perf_get_qtimer_count();
1195
+
1196
+ float * tmp = (float *) spad_dst;
1197
+
1198
+ const uint8_t * restrict src0_row = (const uint8_t *) src0->data;
1199
+ const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
1200
+ float * restrict dst_col = (float *) dst->data;
1201
+
1202
+ // Prefill spad with 2x src0 rows
1203
+ #pragma unroll(2)
1204
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1205
+ const uint32_t is0 = (ir0 - src0_start_row);
1206
+ if (is0 >= HTP_SPAD_SRC0_NROWS) {
1207
+ break;
1208
+ }
1209
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1210
+ src0_row_size_padded, src0_row_size, 2);
1211
+ }
1212
+
1213
+ // Process src0 rows
1214
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1215
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1216
+ mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col);
1217
+
1218
+ // Prefetch next (n + spad_nrows) row
1219
+ const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
1220
+ const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
1221
+ if (pr0 < src0_end_row_x2) {
1222
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
1223
+ src0_row_size_padded, src0_row_size, 2);
1224
+ }
1225
+ }
1226
+
1227
+ // Process the last row (if any)
1228
+ if (src0_end_row != src0_end_row_x2) {
1229
+ const uint32_t ir0 = src0_end_row_x2;
1230
+ const uint32_t is0 = (ir0 - src0_start_row);
1231
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1232
+ src0_row_size_padded, src0_row_size, 1);
1233
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1234
+ mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
1235
+ }
1236
+
1237
+ hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
1238
+
1239
+ t2 = HAP_perf_get_qtimer_count();
1240
+
1241
+ FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
1242
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
1243
+ src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1244
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1245
+ }
1246
+
1247
// Indexes the flattened per-expert row-mapping table; expects `ids` (the
// expert-selection tensor) and `matrix_rows` to be in scope at the
// expansion site.
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)]

// One entry of the expert row-mapping table.
// NOTE(review): callers treat i1 as the expert/src1 row index and i2 as the
// token index — confirm against the code that builds the table.
struct mmid_row_mapping {
    uint32_t i1;
    uint32_t i2;
};
1253
+
1254
+ // q8x4 src1 tensor is already in VTCM spad
1255
+ static void matmul_id(struct htp_matmul_type * mt,
1256
+ struct htp_tensor * restrict src0,
1257
+ struct htp_tensor * restrict src1,
1258
+ struct htp_tensor * restrict ids,
1259
+ struct htp_tensor * restrict dst,
1260
+ struct htp_spad * restrict src0_spad,
1261
+ struct htp_spad * restrict src1_spad,
1262
+ struct htp_spad * restrict src2_spad,
1263
+ struct htp_spad * restrict dst_spad,
1264
+ uint32_t nth,
1265
+ uint32_t ith,
1266
+ uint32_t src0_nrows_per_thread,
1267
+ dma_queue * dma_queue) {
1268
+ htp_matmul_preamble;
1269
+
1270
+ uint64_t t1, t2;
1271
+ t1 = HAP_perf_get_qtimer_count();
1272
+
1273
+ const uint32_t src0_nrows = ne01; // src0 rows per expert
1274
+ const uint32_t src1_nrows = ne11;
1275
+
1276
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1277
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1278
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1279
+
1280
+ // no work for this thread
1281
+ if (src0_start_row >= src0_end_row) {
1282
+ return;
1283
+ }
1284
+
1285
+ const uint32_t n_ids = ids->ne[0]; // n_expert_used
1286
+ const uint32_t n_as = ne02; // n_expert
1287
+
1288
+ const size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
1289
+ const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
1290
+
1291
+ const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0;
1292
+ const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size;
1293
+
1294
+ const size_t dst_row_size = nb1;
1295
+ const size_t src0_row_size = nb01;
1296
+ const size_t src1_row_size = q8x4x2_row_size(ne10);
1297
+
1298
+ const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
1299
+
1300
+ // Per-thread VTCM scratchpads for all tensors
1301
+ // Note that the entire src1 tensor is already in VTCM
1302
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
1303
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1304
+ uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1305
+ uint8_t * restrict src1_data = src1_spad->data;
1306
+
1307
+ for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) {
1308
+ const int32_t cne1 = matrix_row_counts[cur_a];
1309
+
1310
+ if (cne1 == 0) {
1311
+ continue;
1312
+ }
1313
+
1314
+ const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0);
1315
+
1316
+ // Prefill spad with src0 rows
1317
+ #pragma unroll(4)
1318
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1319
+ const int is0 = (ir0 - src0_start_row);
1320
+ if (is0 >= HTP_SPAD_SRC0_NROWS) {
1321
+ break;
1322
+ }
1323
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1324
+ src0_row_size_padded, src0_row_size, 2);
1325
+ }
1326
+
1327
+ // Process src0 rows
1328
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1329
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1330
+
1331
+ for (uint32_t cid = 0; cid < cne1; ++cid) {
1332
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
1333
+ const int rm1 = row_mapping.i1; // expert idx
1334
+ const int rm2 = row_mapping.i2; // token idx
1335
+
1336
+ const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
1337
+ const uint8_t * restrict src1_col =
1338
+ (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1339
+ float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1340
+
1341
+ mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
1342
+ }
1343
+
1344
+ // Prefetch next (n + spad_nrows) row
1345
+ const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
1346
+ const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
1347
+ if (pr0 < src0_end_row_x2) {
1348
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
1349
+ src0_row_size_padded, src0_row_size, 2);
1350
+ }
1351
+ }
1352
+
1353
+ // Process the last row (if any)
1354
+ if (src0_end_row != src0_end_row_x2) {
1355
+ uint32_t ir0 = src0_end_row_x2;
1356
+ const uint32_t is0 = (ir0 - src0_start_row);
1357
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1358
+ src0_row_size_padded, src0_row_size, 1);
1359
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1360
+
1361
+ for (uint32_t cid = 0; cid < cne1; ++cid) {
1362
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
1363
+ const int rm1 = row_mapping.i1; // expert idx
1364
+ const int rm2 = row_mapping.i2; // token idx
1365
+
1366
+ const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx
1367
+ const uint8_t * restrict src1_col =
1368
+ (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
1369
+ float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
1370
+
1371
+ mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
1372
+ }
1373
+ }
1374
+ }
1375
+
1376
+ t2 = HAP_perf_get_qtimer_count();
1377
+
1378
+ FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
1379
+ ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
1380
+ src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
1381
+ dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1382
+ }
1383
+
1384
+ // q8x4 src1 tensor is already in VTCM spad
1385
+ static void matvec_id(struct htp_matmul_type * mt,
1386
+ struct htp_tensor * restrict src0,
1387
+ struct htp_tensor * restrict src1,
1388
+ struct htp_tensor * restrict src2,
1389
+ struct htp_tensor * restrict dst,
1390
+ struct htp_spad * restrict src0_spad,
1391
+ struct htp_spad * restrict src1_spad,
1392
+ struct htp_spad * restrict src2_spad,
1393
+ struct htp_spad * restrict dst_spad,
1394
+ uint32_t nth,
1395
+ uint32_t ith,
1396
+ uint32_t src0_nrows_per_thread,
1397
+ dma_queue * dma_queue) {
1398
+ htp_matmul_preamble;
1399
+
1400
+ uint64_t t1, t2;
1401
+ t1 = HAP_perf_get_qtimer_count();
1402
+
1403
+ const uint32_t src0_nrows = ne01; // src0 rows per expert
1404
+
1405
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
1406
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
1407
+ const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U);
1408
+
1409
+ // no work for this thread
1410
+ if (src0_start_row >= src0_end_row) {
1411
+ return;
1412
+ }
1413
+
1414
+ assert(ne13 % ne03 == 0);
1415
+
1416
+ const size_t dst_row_size = nb1;
1417
+ const size_t src0_row_size = nb01;
1418
+ const size_t src1_row_size = q8x4x2_row_size(ne10);
1419
+
1420
+ const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
1421
+
1422
+ const uint32_t n_aids = src2->ne[0]; // num activated experts
1423
+ const uint32_t n_ids = ne02; // num experts
1424
+
1425
+ // Per-thread VTCM scratchpads for all tensors
1426
+ // Note that the entire src1 tensor is already in VTCM
1427
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
1428
+ uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith;
1429
+ uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
1430
+ uint8_t * restrict src1_data = src1_spad->data;
1431
+
1432
+ for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert
1433
+ const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]);
1434
+ assert(eid < n_ids);
1435
+
1436
+ const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02;
1437
+ const uint8_t * restrict src1_col = (const uint8_t *) src1_data;
1438
+ float * restrict dst_row = (float *) (dst->data + ie1 * nb1);
1439
+
1440
+ // Prefill spad with src0 rows
1441
+ #pragma unroll(4)
1442
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1443
+ const int is0 = (ir0 - src0_start_row);
1444
+ if (is0 >= HTP_SPAD_SRC0_NROWS) {
1445
+ break;
1446
+ }
1447
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1448
+ src0_row_size_padded, src0_row_size, 2);
1449
+ }
1450
+
1451
+ // Process src0 rows
1452
+ for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
1453
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1454
+ mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
1455
+
1456
+ // Prefetch next (n + spad_nrows) row
1457
+ const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
1458
+ const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
1459
+ if (pr0 < src0_end_row_x2) {
1460
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size,
1461
+ src0_row_size_padded, src0_row_size, 2);
1462
+ }
1463
+ }
1464
+
1465
+ // Process the last row (if any)
1466
+ if (src0_end_row != src0_end_row_x2) {
1467
+ uint32_t ir0 = src0_end_row_x2;
1468
+ const uint32_t is0 = (ir0 - src0_start_row);
1469
+ dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size,
1470
+ src0_row_size_padded, src0_row_size, 1);
1471
+ const uint8_t * ss0 = dma_queue_pop(dma_queue);
1472
+ mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
1473
+ }
1474
+ }
1475
+
1476
+ t2 = HAP_perf_get_qtimer_count();
1477
+
1478
+ FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
1479
+ ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
1480
+ src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
1481
+ dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1482
+ }
1483
+
1484
+ // *** matmul in fp16
1485
+
1486
+ static void matmul_f16_f32(struct htp_tensor * restrict src0,
1487
+ struct htp_tensor * restrict src1,
1488
+ struct htp_tensor * restrict dst,
1489
+ struct htp_spad * restrict src0_spad,
1490
+ struct htp_spad * restrict src1_spad,
1491
+ struct htp_spad * restrict dst_spad,
1492
+ uint32_t nth,
1493
+ uint32_t ith,
1494
+ uint32_t src0_nrows_per_thread,
1495
+ dma_queue * dma_queue) {
1496
+ htp_matmul_preamble;
1497
+
1498
+ uint64_t t1, t2;
1499
+ t1 = HAP_perf_get_qtimer_count();
1500
+
1501
+ const size_t src0_row_size = sizeof(__fp16) * ne00;
1502
+ const size_t src1_row_size = sizeof(float) * ne10;
1503
+
1504
+ assert(ne12 % ne02 == 0);
1505
+ assert(ne13 % ne03 == 0);
1506
+
1507
+ // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
1508
+ const uint32_t nr0 = ne0;
1509
+
1510
+ // This is the size of the rest of the dimensions of the result
1511
+ const uint32_t nr1 = ne1 * ne2 * ne3;
1512
+
1513
+ uint32_t chunk_size = 64;
1514
+
1515
+ // distribute the thread work across the inner or outer loop based on which one is larger
1516
+ uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
1517
+ uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
1518
+
1519
+ // The number of elements in each chunk
1520
+ const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
1521
+ const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
1522
+
1523
+ uint32_t current_chunk = ith;
1524
+
1525
+ const uint32_t ith0 = current_chunk % nchunk0;
1526
+ const uint32_t ith1 = current_chunk / nchunk0;
1527
+
1528
+ const uint32_t ir0_start = dr0 * ith0;
1529
+ const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
1530
+
1531
+ const uint32_t ir1_start = dr1 * ith1;
1532
+ const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
1533
+
1534
+ // broadcast factors
1535
+ const uint32_t r2 = ne12 / ne02;
1536
+ const uint32_t r3 = ne13 / ne03;
1537
+
1538
+ // no work for this thread
1539
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
1540
+ return;
1541
+ }
1542
+
1543
+ // block-tiling attempt
1544
+ const uint32_t blck_0 = 64;
1545
+ const uint32_t blck_1 = 64;
1546
+
1547
+ float tmp[32];
1548
+
1549
+ for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
1550
+ for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
1551
+ for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) {
1552
+ const uint32_t i13 = (ir1 / (ne12 * ne1));
1553
+ const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
1554
+ const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
1555
+
1556
+ // broadcast src0 into src1
1557
+ const uint32_t i03 = i13 / r3;
1558
+ const uint32_t i02 = i12 / r2;
1559
+
1560
+ const uint32_t i1 = i11;
1561
+ const uint32_t i2 = i12;
1562
+ const uint32_t i3 = i13;
1563
+
1564
+ const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
1565
+ const uint8_t * restrict src1_col =
1566
+ (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size;
1567
+ float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
1568
+
1569
+ for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) {
1570
+ vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col);
1571
+ }
1572
+
1573
+ hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
1574
+ }
1575
+ }
1576
+ }
1577
+
1578
+ t2 = HAP_perf_get_qtimer_count();
1579
+
1580
+ FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
1581
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
1582
+ src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
1583
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1584
+ }
1585
+
1586
+ // *** dynamic quant
1587
+
1588
// Quantize one block of 128 fp32 values (4 HVX vectors) to int8 with a
// single shared fp16 scale. y_q receives 128 int8 quants; y_d receives the
// scale replicated across a full fp16 vector via an unaligned 128-byte
// store — the caller keeps only the leading lanes it needs.
static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
    // x and y_q must be HVX-vector (128-byte) aligned; y_d may be unaligned
    assert((unsigned long) x % 128 == 0);
    assert((unsigned long) y_q % 128 == 0);

    HVX_Vector * vx = (HVX_Vector *) x;

    // Load and convert into QF32 (subtracting zero converts sf -> qf32)
    HVX_Vector zero = Q6_V_vsplat_R(0);
    HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
    HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
    HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
    HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements

    // Convert into fp16; vdeal restores lane order after the widening combine
    HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
    HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));

    // Compute max(|x|) across the whole block
    HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
    vmax_hf = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf);

    // Replicate first fp16 scale across all lanes
    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
    vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);

    // Scale d = max / 127 (0x2008 is fp16 for ~1.0/127.0)
    HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
    HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);

    // Unaligned full-vector store of the replicated scale
    *(HVX_UVector *) y_d = vd_hf;

    // Divide input by the scale (multiply by its reciprocal)
    HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf);
    vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
    vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));

    // Convert to int8 (round + saturate, then pack 2x 64 halfwords into bytes)
    HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
    HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
    HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);

    *(HVX_Vector *) y_q = vx_i8;
}
1630
+
1631
+ // Overrides input x
1632
+ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
1633
+ assert(k % 32 == 0);
1634
+ const uint32_t qk = QK_Q8_0x4x2;
1635
+ const uint32_t nb = (k + qk - 1) / qk;
1636
+
1637
+ const uint32_t qrow_size = k; // int8
1638
+
1639
+ const uint32_t dblk_size = 8 * 2; // 8x __fp16
1640
+ const uint32_t qblk_size = QK_Q8_0x4x2; // int8
1641
+
1642
+ uint8_t * restrict y_q = (y + 0); // quants first
1643
+ uint8_t * restrict y_d = (y + qrow_size); // then scales
1644
+
1645
+ // Temp scales override input since we're working off of the aligned temp buffer in VTCM
1646
+ uint8_t * restrict t_d = (uint8_t *) x;
1647
+
1648
+ for (uint32_t i = 0; i < nb; i++) {
1649
+ quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
1650
+ t_d + (i * 2 + 0) * dblk_size / 2);
1651
+ quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
1652
+ t_d + (i * 2 + 1) * dblk_size / 2);
1653
+ }
1654
+
1655
+ // now copy the scales into final location
1656
+ hvx_copy_fp16_ua(y_d, t_d, nb * 8);
1657
+ }
1658
+
1659
+ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
1660
+ uint8_t * restrict dst,
1661
+ struct htp_spad * spad,
1662
+ uint32_t nth,
1663
+ uint32_t ith,
1664
+ uint32_t nrows_per_thread) {
1665
+ uint64_t t1 = HAP_perf_get_qtimer_count();
1666
+
1667
+ const uint32_t ne0 = src->ne[0];
1668
+ const uint32_t ne1 = src->ne[1];
1669
+ const uint32_t ne2 = src->ne[2];
1670
+ const uint32_t ne3 = src->ne[3];
1671
+
1672
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
1673
+
1674
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
1675
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
1676
+
1677
+ const size_t src_row_size = src->nb[1];
1678
+ const size_t dst_row_size = q8x4x2_row_size(ne0);
1679
+
1680
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first);
1681
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
1682
+ uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
1683
+
1684
+ const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
1685
+ memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding
1686
+
1687
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
1688
+ htp_l2fetch(src_data, 2, src_row_size, src_row_size);
1689
+ hvx_copy_fp32_aa(tmp_data, src_data, ne0);
1690
+
1691
+ // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
1692
+ quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0);
1693
+ dst_data += dst_row_size;
1694
+ src_data += src_row_size;
1695
+ }
1696
+
1697
+ uint64_t t2 = HAP_perf_get_qtimer_count();
1698
+
1699
+ FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
1700
+ ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
1701
+ }
1702
+
1703
+ static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
1704
+ struct htp_ops_context * octx = data;
1705
+ quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
1706
+ }
1707
+
1708
+ // ** matmul callbacks for worker_pool
1709
+
1710
+ static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1711
+ struct htp_ops_context * octx = data;
1712
+
1713
+ struct htp_matmul_type mt;
1714
+ mt.type = "q4x4x2-q8x4x2";
1715
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1716
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1717
+
1718
+ matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
1719
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1720
+ }
1721
+
1722
+ static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1723
+ struct htp_ops_context * octx = data;
1724
+
1725
+ struct htp_matmul_type mt;
1726
+ mt.type = "q4x4x2-q8x4x2";
1727
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1728
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1729
+
1730
+ matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
1731
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1732
+ }
1733
+
1734
+ static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1735
+ struct htp_ops_context * octx = data;
1736
+
1737
+ struct htp_matmul_type mt;
1738
+ mt.type = "q8x4x2-q8x4x2";
1739
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1740
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1741
+
1742
+ matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
1743
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1744
+ }
1745
+
1746
+ static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1747
+ struct htp_ops_context * octx = data;
1748
+
1749
+ struct htp_matmul_type mt;
1750
+ mt.type = "q8x4x2-q8x4x2";
1751
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1752
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1753
+
1754
+ matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
1755
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1756
+ }
1757
+
1758
+ static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1759
+ struct htp_ops_context * octx = data;
1760
+
1761
+ struct htp_matmul_type mt;
1762
+ mt.type = "mxfp4x4x2-q8x4x2";
1763
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1764
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1765
+
1766
+ matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
1767
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1768
+ }
1769
+
1770
+ static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1771
+ struct htp_ops_context * octx = data;
1772
+
1773
+ struct htp_matmul_type mt;
1774
+ mt.type = "mxfp4x4x2-q8x4x2";
1775
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1776
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1777
+
1778
+ matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
1779
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1780
+ }
1781
+
1782
+ static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) {
1783
+ struct htp_ops_context * octx = data;
1784
+ matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
1785
+ octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1786
+ }
1787
+
1788
+ // ** matmul-id callbacks for worker_pool
1789
+
1790
+ static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1791
+ struct htp_ops_context * octx = data;
1792
+
1793
+ struct htp_matmul_type mt;
1794
+ mt.type = "q4x4x2-q8x4x2";
1795
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1796
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1797
+
1798
+ matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
1799
+ &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1800
+ }
1801
+
1802
+ static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1803
+ struct htp_ops_context * octx = data;
1804
+
1805
+ struct htp_matmul_type mt;
1806
+ mt.type = "q4x4x2-q8x4x2";
1807
+ mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
1808
+ mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
1809
+
1810
+ matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
1811
+ &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1812
+ }
1813
+
1814
+ static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1815
+ struct htp_ops_context * octx = data;
1816
+
1817
+ struct htp_matmul_type mt;
1818
+ mt.type = "q8x4x2-q8x4x2";
1819
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1820
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1821
+
1822
+ matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
1823
+ &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1824
+ }
1825
+
1826
+ static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1827
+ struct htp_ops_context * octx = data;
1828
+
1829
+ struct htp_matmul_type mt;
1830
+ mt.type = "q8x4x2-q8x4x2";
1831
+ mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
1832
+ mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
1833
+
1834
+ matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
1835
+ &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1836
+ }
1837
+
1838
+ static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1839
+ struct htp_ops_context * octx = data;
1840
+
1841
+ struct htp_matmul_type mt;
1842
+ mt.type = "mxfp4x4x2-q8x4x2";
1843
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1844
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1845
+
1846
+ matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
1847
+ &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1848
+ }
1849
+
1850
+ static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
1851
+ struct htp_ops_context * octx = data;
1852
+
1853
+ struct htp_matmul_type mt;
1854
+ mt.type = "mxfp4x4x2-q8x4x2";
1855
+ mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
1856
+ mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
1857
+
1858
+ matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
1859
+ &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
1860
+ }
1861
+
1862
+ // ** main matmul entry point
1863
+
1864
+ int op_matmul(struct htp_ops_context * octx) {
1865
+ const struct htp_tensor * src0 = &octx->src0;
1866
+ const struct htp_tensor * src1 = &octx->src1;
1867
+ struct htp_tensor * dst = &octx->dst;
1868
+
1869
+ htp_matmul_preamble;
1870
+
1871
+ const char * op_type;
1872
+
1873
+ const uint32_t src0_nrows = ne01 * ne02 * ne03;
1874
+ const uint32_t src1_nrows = ne11 * ne12 * ne13;
1875
+
1876
+ const size_t src0_row_size = nb01;
1877
+ const size_t dst_row_size = nb1;
1878
+ size_t src1_row_size = nb11;
1879
+
1880
+ const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
1881
+ size_t src1_row_size_padded;
1882
+
1883
+ worker_callback_t quant_job_func;
1884
+ worker_callback_t matmul_job_func;
1885
+
1886
+ bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
1887
+
1888
+ switch (src0->type) {
1889
+ case HTP_TYPE_Q4_0:
1890
+ op_type = "q4x4x2-fp32";
1891
+ quant_job_func = htp_quantize_fp32_q8x4x2;
1892
+ if (src1_nrows > 1) {
1893
+ matmul_job_func = htp_matmul_q4x4x2_q8x4x2;
1894
+ } else {
1895
+ matmul_job_func = htp_matvec_q4x4x2_q8x4x2;
1896
+ }
1897
+
1898
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
1899
+
1900
+ // Entire src1 tensor is placed into the VTCM
1901
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
1902
+
1903
+ octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
1904
+ octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
1905
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
1906
+
1907
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
1908
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
1909
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
1910
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
1911
+ }
1912
+
1913
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
1914
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
1915
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
1916
+ break;
1917
+
1918
+ case HTP_TYPE_Q8_0:
1919
+ op_type = "q8x4x2-fp32";
1920
+ quant_job_func = htp_quantize_fp32_q8x4x2;
1921
+ if (src1_nrows > 1) {
1922
+ matmul_job_func = htp_matmul_q8x4x2_q8x4x2;
1923
+ } else {
1924
+ matmul_job_func = htp_matvec_q8x4x2_q8x4x2;
1925
+ }
1926
+
1927
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
1928
+
1929
+ // Entire src1 tensor is placed into the VTCM
1930
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
1931
+
1932
+ octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
1933
+ octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
1934
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
1935
+
1936
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
1937
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
1938
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
1939
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
1940
+ }
1941
+
1942
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
1943
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
1944
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
1945
+ break;
1946
+
1947
+ case HTP_TYPE_MXFP4:
1948
+ op_type = "mxfp4x4x2-f32";
1949
+ quant_job_func = htp_quantize_fp32_q8x4x2;
1950
+ if (src1_nrows > 1) {
1951
+ matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2;
1952
+ } else {
1953
+ matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2;
1954
+ }
1955
+
1956
+ src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
1957
+
1958
+ // Entire src1 tensor is placed into the VTCM
1959
+ // For other tensors we allocate N rows per thread, padded to HVX vector size
1960
+
1961
+ octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
1962
+ octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
1963
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
1964
+
1965
+ // src0 spad is also used in dynamic quantizer to store padded src1 rows
1966
+ src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
1967
+ if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
1968
+ octx->src0_spad.size_per_thread = src1_row_size_padded;
1969
+ }
1970
+
1971
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
1972
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
1973
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
1974
+ break;
1975
+
1976
+ case HTP_TYPE_F16:
1977
+ op_type = "f16-f32";
1978
+ quant_job_func = NULL; // htp_quantize_f32_f16;
1979
+ matmul_job_func = htp_matmul_f16_f32;
1980
+
1981
+ // For all tensors we allocate N rows per thread, padded to HVX vector size
1982
+ octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
1983
+ octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256);
1984
+ octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256);
1985
+
1986
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
1987
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
1988
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
1989
+
1990
+ need_quant = false;
1991
+ break;
1992
+
1993
+ default:
1994
+ return HTP_STATUS_NO_SUPPORT;
1995
+ }
1996
+
1997
+ // VTCM scratchpads for all tensors
1998
+ size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
1999
+
2000
+ FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type,
2001
+ octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
2002
+
2003
+ FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0],
2004
+ src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
2005
+ dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
2006
+
2007
+ // Make sure the reserved vtcm size is sufficient
2008
+ if (octx->ctx->vtcm_size < spad_size) {
2009
+ FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
2010
+ octx->ctx->vtcm_size, spad_size);
2011
+ return HTP_STATUS_VTCM_TOO_SMALL;
2012
+ }
2013
+
2014
+ octx->src0_spad.data = octx->ctx->vtcm_base;
2015
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
2016
+ octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
2017
+
2018
+ octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
2019
+ octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
2020
+
2021
+ if (need_quant) {
2022
+ // Run quant jobs
2023
+ const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
2024
+ octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
2025
+ worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
2026
+ }
2027
+
2028
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
2029
+ // Run matmul jobs
2030
+ const uint32_t n_matmul_jobs = octx->n_threads;
2031
+ worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs);
2032
+ }
2033
+
2034
+ return HTP_STATUS_OK;
2035
+ }
2036
+
2037
// ** main matmul-id entry point

// Indirect (MoE, mul_mat_id) matmul: src2 ("ids") selects, per token, which
// src0 expert matrices multiply the src1 rows. Picks worker callbacks by
// src0->type, sizes the VTCM scratchpads (src0 | src1 | src2 | dst), builds
// the expert->row mapping tables in the src2 spad, then runs quantize and
// compute jobs on the worker pool.
// Returns HTP_STATUS_OK, HTP_STATUS_NO_SUPPORT (unhandled src0 type), or
// HTP_STATUS_VTCM_TOO_SMALL.
int op_matmul_id(struct htp_ops_context * octx) {
    const struct htp_tensor * src0 = &octx->src0;
    const struct htp_tensor * src1 = &octx->src1;
    const struct htp_tensor * ids = &octx->src2;
    struct htp_tensor * dst = &octx->dst;

    htp_matmul_preamble;

    const char * op_type;

    worker_callback_t quant_job_func;
    worker_callback_t matmul_id_job_func;

    const size_t src0_row_size = nb01;
    const size_t dst_row_size = nb1;

    // src0 rows are DMA'd in padded to 128 bytes
    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);

    const uint32_t src0_nrows = ne01; // per expert
    const uint32_t src1_nrows = ne11 * ne12 * ne13;

    size_t src1_row_size;
    size_t src1_row_size_padded;

    // row groups
    const int n_ids = ids->ne[0]; // n_expert_used
    const int n_as = ne02; // n_expert

    // sizes of the per-expert row-count and row-mapping tables kept in the src2 spad
    size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
    size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);

    // NOTE(review): the op_type strings below ("q4x2x2", "q8x2x2", "mxfp4x2x2")
    // look like typos for the "x4x2" naming used by op_matmul — log-only, but
    // worth confirming against upstream.
    switch (src0->type) {
        case HTP_TYPE_Q4_0:
            op_type = "q4x2x2-f32";
            quant_job_func = htp_quantize_fp32_q8x4x2;
            src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
            if (src1_nrows > 1) {
                matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
            } else {
                matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
            }

            // Entire src1 tensor is placed into the VTCM
            // For other tensors we allocate N rows per thread, padded to HVX vector size
            octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);

            // src0 spad is also used in dynamic quantizer to store padded src1 rows
            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
                octx->src0_spad.size_per_thread = src1_row_size_padded;
            }

            // src1 and src2 spads are shared by all threads; src0 and dst are per-thread
            octx->src2_spad.size = octx->src2_spad.size_per_thread;
            octx->src1_spad.size = octx->src1_spad.size_per_thread;
            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
            octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
            break;

        case HTP_TYPE_Q8_0:
            op_type = "q8x2x2-f32";
            quant_job_func = htp_quantize_fp32_q8x4x2;
            src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
            if (src1_nrows > 1) {
                matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
            } else {
                matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
            }

            // Entire src1 tensor is placed into the VTCM
            // For other tensors we allocate N rows per thread, padded to HVX vector size
            octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);

            // src0 spad is also used in dynamic quantizer to store padded src1 rows
            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
                octx->src0_spad.size_per_thread = src1_row_size_padded;
            }

            // src1 and src2 spads are shared by all threads; src0 and dst are per-thread
            octx->src2_spad.size = octx->src2_spad.size_per_thread;
            octx->src1_spad.size = octx->src1_spad.size_per_thread;
            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
            octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
            break;

        case HTP_TYPE_MXFP4:
            op_type = "mxfp4x2x2-f32";
            quant_job_func = htp_quantize_fp32_q8x4x2;
            src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
            if (src1_nrows > 1) {
                matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
            } else {
                matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
            }

            // Entire src1 tensor is placed into the VTCM
            // For other tensors we allocate N rows per thread, padded to HVX vector size
            octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);

            // src0 spad is also used in dynamic quantizer to store padded src1 rows
            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
                octx->src0_spad.size_per_thread = src1_row_size_padded;
            }

            // src1 and src2 spads are shared by all threads; src0 and dst are per-thread
            octx->src2_spad.size = octx->src2_spad.size_per_thread;
            octx->src1_spad.size = octx->src1_spad.size_per_thread;
            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
            octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
            break;

        default:
            return HTP_STATUS_NO_SUPPORT;
    }

    size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;

    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type,
         octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);

    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
         ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
         src1->data, dst->data);

    // Make sure the reserved vtcm size is sufficient
    if (octx->ctx->vtcm_size < spad_size) {
        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
             octx->ctx->vtcm_size, spad_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    // Carve the reserved VTCM into src0 | src1 | src2 | dst scratchpads
    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
    octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size;

    octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
    octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even

    if (src1_nrows > 1) {
        // initialize matrix_row_counts and map
        // Layout in the src2 spad: [uint32_t counts[n_as]] [mmid_row_mapping map[...]]
        // NOTE(review): pointer arithmetic on `void *` below is a GNU extension.
        uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0;
        struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size;

        memset(matrix_row_counts, 0, n_as * sizeof(uint32_t));

        // group rows by src0 matrix
        // For each (token, used-expert) pair, read the expert index from ids
        // and append the (expert-slot, token) mapping for that expert.
        for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx
            for (uint32_t id = 0; id < n_ids; ++id) { // expert idx
                const uint32_t i02 =
                    *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);

                // i02 is unsigned, so the >= 0 check is vacuous; the upper-bound
                // check guards against out-of-range expert indices in ids
                assert(i02 >= 0 && i02 < n_as);

                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 };
                matrix_row_counts[i02] += 1;
            }
        }
    }

    // Setup worker pool callbacks
    if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
        // Run quant jobs (dynamic q8x4x2 quantization of src1)
        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
        octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
    }

    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
        // Run matmul-id jobs
        const uint32_t n_matmul_jobs = octx->n_threads;
        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs);
    }

    return HTP_STATUS_OK;
}