whispercpp 1.3.3 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
@@ -0,0 +1,3151 @@
1
+ #include <assert.h>
2
+ #include <inttypes.h>
3
+ #include <stdio.h>
4
+ #include <stdlib.h>
5
+ #include <string.h>
6
+ #include <time.h>
7
+
8
+ #include <atomic>
9
+ #include <chrono>
10
+ #include <cstddef>
11
+ #include <mutex>
12
+ #include <stdexcept>
13
+ #include <string>
14
+
15
+ #ifdef _WIN32
16
+ # include <sal.h>
17
+ # ifndef _WINDOWS
18
+ # define _WINDOWS
19
+ # endif
20
+ #else
21
+ # include <semaphore.h>
22
+ # include <unistd.h>
23
+ #endif
24
+
25
+ #pragma clang diagnostic ignored "-Wnested-anon-types"
26
+ #pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
27
+
28
+ #include "htp-utils.h"
29
+
30
+ #include <AEEStdErr.h>
31
+ #include <dspqueue.h>
32
+ #include <rpcmem.h>
33
+
34
+ #define GGML_COMMON_IMPL_CPP
35
+ #include "ggml-backend-impl.h"
36
+ #include "ggml-common.h"
37
+ #include "ggml-hexagon.h"
38
+ #include "ggml-impl.h"
39
+ #include "ggml-quants.h"
40
+ #include "op-desc.h"
41
+ #include "htp-msg.h"
42
+ #include "htp_iface.h"
43
+
44
// Backend configuration knobs (file-scope; presumably set once during init
// from env/options elsewhere in this file, then read by the code below).
static size_t opt_ndev         = 1; // number of HTP devices/sessions to use
static size_t opt_nhvx         = 0; // HVX thread count; 0 = use all
static int    opt_arch         = 0; // Hexagon arch; 0 = autodetect
static int    opt_etm          = 0; // ETM trace flag -- TODO confirm semantics
static int    opt_verbose      = 0; // verbose log level (gates HEX_VERBOSE, >1/>2 dump more)
static int    opt_profile      = 0; // per-op profiling output (see dump_op_prof)
static int    opt_hostbuf      = 1; // presumably prefer host-side buffers -- verify at use site
static int    opt_experimental = 0; // enable experimental ops

// Enable all stages by default
static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE;
static int opt_opsync = 0;  // synchronous ops
56
+
57
// Log a debug message only when verbose logging is enabled.
// Wrapped in do { } while (0) so the macro expands to exactly one
// statement: the previous bare-`if` form would swallow a following
// `else` at the call site (dangling-else hazard) and could not be
// safely used as the body of an un-braced `if`.
#define HEX_VERBOSE(...)                     \
    do {                                     \
        if (opt_verbose) {                   \
            GGML_LOG_DEBUG(__VA_ARGS__);     \
        }                                    \
    } while (0)
59
+
60
// Nonzero when `addr` is aligned to `align` bytes; `align` must be a
// power of two (the check uses an and-mask, not a modulo).
static inline uint64_t hex_is_aligned(void * addr, uint32_t align) {
    const size_t mask = (size_t) align - 1;
    return ((size_t) addr & mask) == 0;
}
63
+
64
// Round `n` up to the next multiple of `m` (m must be > 0).
static inline size_t hex_round_up(size_t n, size_t m) {
    const size_t blocks = (n + m - 1) / m;
    return blocks * m;
}
67
+
68
// Map an HTP response status code (HTP_STATUS_* from the HTP headers) to a
// short human-readable string for log messages. Unrecognized codes map to
// "UNKNOWN" rather than aborting.
static const char * status_to_str(uint32_t status) {
    switch (status) {
        case HTP_STATUS_OK:
            return "OK";
        case HTP_STATUS_NO_SUPPORT:
            return "NO-SUPPORT";
        case HTP_STATUS_INVAL_PARAMS:
            return "INVAL-PARAMS";
        case HTP_STATUS_VTCM_TOO_SMALL:
            return "VTCM-TOO-SMALL";
        case HTP_STATUS_INTERNAL_ERR:
            return "INTERNAL-ERROR";
        default:
            return "UNKNOWN";
    }
}
84
+
85
+ // ** debug helpers
86
+
87
// Log a one-line description (names/dims/types/strides/buffers + request
// flags) of an op that is about to be executed. No-op unless verbose
// logging is enabled; op_desc construction is skipped entirely in that case.
static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) {
    if (!opt_verbose) return;

    op_desc desc(op);
    GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(),
                   ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags);
}
94
+
95
// Log the result of a supports-op query (yes/no) for an op, with the same
// descriptor fields as the exec dump. No-op unless verbose logging is on.
static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) {
    if (!opt_verbose) return;

    op_desc desc(op);
    GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(),
                   ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no");
}
102
+
103
+ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op,
104
+ uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) {
105
+ if (!opt_profile) return;
106
+
107
+ op_desc desc(op);
108
+ GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(),
109
+ ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs,
110
+ op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec);
111
+ }
112
+
113
+ // ** backend sessions
114
+
115
// One connection to an HTP (Hexagon) device: the FastRPC handle, the DSP
// message queue, and the buffer types whose memory gets mapped into the
// session. Constructor allocates (and may throw); destructor releases.
struct ggml_hexagon_session {
    ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false);
    ~ggml_hexagon_session() noexcept(true);

    void allocate(int dev_id) noexcept(false);  // acquire session/handle/queue; throws on failure
    void release() noexcept(true);              // tear down (body not shown here)

    // Queue one request to the DSP, attaching buffer references.
    // sync=true additionally drains all outstanding responses (flush()).
    void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false);
    // Wait for every enqueued op to be acknowledged by the DSP.
    void flush();

    ggml_backend_buffer_type buffer_type        = {};  // regular buffers
    ggml_backend_buffer_type repack_buffer_type = {};  // repacked-weight buffers

    std::string     name;        // used as a prefix in all log messages
    remote_handle64 handle;      // FastRPC handle
    dspqueue_t      queue;       // request/response queue to the DSP
    uint32_t        session_id;
    uint32_t        domain_id;   // FastRPC domain, used for mmap/munmap
    uint64_t        queue_id;
    int             dev_id;
    bool            valid_session;  // validity flags: which resources were
    bool            valid_handle;   // successfully created (consulted by
    bool            valid_queue;    // release(), presumably)
    bool            valid_iface;
    std::atomic<int> op_pending;  // requests written but not yet acknowledged (see enqueue/flush)
    uint32_t        prof_usecs;   // profiling data copied from the last response
    uint32_t        prof_cycles;
    uint32_t        prof_pkts;
};
144
+
145
// Send one general request to the DSP over the shared queue, attaching
// n_bufs buffer references. Aborts the process if the queue write fails.
// With sync=true this blocks until *all* outstanding requests (not just
// this one) have been acknowledged.
void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
    // Bump pending count (decremented in session::flush once we get the response)
    this->op_pending++; // atomic inc

    int err = dspqueue_write(this->queue,
                             0,       // flags - the framework will autoset this
                             n_bufs,  // number of buffers
                             bufs,    // buffer references
                             sizeof(req),
                             (const uint8_t *) &req, // Message
                             1000000                 // Timeout
    );

    if (err != 0) {
        GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
    }

    if (sync) {
        flush();
    }
}
166
+
167
// Flush HTP response queue i.e. wait for all outstanding requests to complete.
// Pairs with enqueue(): each successfully read response decrements
// op_pending, which enqueue() incremented.
void ggml_hexagon_session::flush() {
    dspqueue_t q = this->queue;

    // Repeatedly read packets from the queue until it's empty. We don't
    // necessarily get a separate callback for each packet, and new packets
    // may arrive while we're processing the previous one.

    while (this->op_pending) {
        struct htp_general_rsp rsp;
        uint32_t rsp_size;
        uint32_t flags;

        struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
        uint32_t n_bufs;

        // Read response packet from queue
        int err = dspqueue_read(q, &flags,
                                HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
                                &n_bufs,                // Number of buffer references
                                bufs,                   // Buffer references
                                sizeof(rsp),            // Max message length
                                &rsp_size,              // Message length
                                (uint8_t *) &rsp,
                                1000000);               // Timeout

        if (err == AEE_EEXPIRED) {
            // Read timed out with ops still pending: just retry.
            // TODO: might need to bail out if the HTP is stuck on something
            continue;
        }

        if (err != 0) {
            GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err);
        }

        // Basic sanity checks
        if (rsp_size != sizeof(rsp)) {
            GGML_ABORT("ggml-hex: dspcall : bad response (size)\n");
        }

        if (rsp.status != HTP_STATUS_OK) {
            // Logged but not propagated -- callers currently cannot observe
            // per-op failures.
            GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status));
            // TODO: handle errors
        }

        // TODO: update profiling implementation, currently only works for opt_opsync mode
        this->prof_usecs  = rsp.prof_usecs;
        this->prof_cycles = rsp.prof_cycles;
        this->prof_pkts   = rsp.prof_pkts;

        this->op_pending--; // atomic dec
    }
}
220
+
221
+ // ** backend buffers
222
+
223
+ struct ggml_backend_hexagon_buffer_type_context {
224
+ ggml_backend_hexagon_buffer_type_context(const std::string & name, ggml_hexagon_session * sess) {
225
+ this->sess = sess;
226
+ this->name = name;
227
+ }
228
+
229
+ ggml_hexagon_session * sess;
230
+ std::string name;
231
+ };
232
+
233
// Host buffer backed by rpcmem shared memory that can be mapped into a
// Hexagon session's address space over FastRPC. The allocation happens in
// the constructor (throws on failure); the FastRPC mapping is deferred
// until mmap() is first called.
struct ggml_backend_hexagon_buffer_context {
    // Map this buffer into session `s` by fd. Returns false on failure.
    // Does NOT set this->mapped -- mmap() below is responsible for that.
    bool mmap_to(ggml_hexagon_session * s) {
        HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n",
                    s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd,
                    (int) this->repack);

        int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD);
        if (err != 0) {
            GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n",
                           s->domain_id, this->size, this->fd, (unsigned) err);
            return false;
        }

        return true;
    }

    // Idempotent: map into the primary session on first call.
    bool mmap() {
        if (this->mapped) {
            return true;
        }
        if (!mmap_to(this->sess)) {
            return false;
        }
        this->mapped = true;
        return true;
    }

    // Idempotent: unmap from the primary session if currently mapped.
    void munmap() {
        if (!this->mapped) {
            return;
        }

        fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size);
        this->mapped = false;
    }

    // Allocate `size` bytes (+ one 4 KiB padding page) of rpcmem and
    // resolve the fd needed for FastRPC mapping.
    // Throws std::runtime_error if allocation or fd lookup fails.
    ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
        size += 4 * 1024; // extra page for padding

        // rpcmem_alloc2 may be unresolved at runtime (presumably a weak /
        // dynamically looked-up symbol) -- fall back to rpcmem_alloc.
        if (rpcmem_alloc2) {
            this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
        } else {
            GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str());
            this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
        }

        if (!this->base) {
            GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
            throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
        }

        this->fd = rpcmem_to_fd(this->base);
        if (this->fd < 0) {
            GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base);
            rpcmem_free(this->base);
            this->base = NULL;
            throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)");
        }

        HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(),
                    (void *) this->base, size, this->fd, (int) repack);

        this->sess   = sess;
        this->size   = size;
        this->mapped = false;
        this->repack = repack;
    }

    ~ggml_backend_hexagon_buffer_context() {
        munmap();
        if (this->base) {
            rpcmem_free(this->base);
            this->base = NULL;
        }
    }

    ggml_hexagon_session * sess; // primary session (not owned)
    uint8_t * base;              // rpcmem allocation; NULL once freed
    size_t    size;              // allocated size, incl. the padding page
    int       fd;                // fd for the region (from rpcmem_to_fd)
    bool      mapped;            // mmap is done
    bool      repack;            // repacked buffer
};
316
+
317
+ static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) {
318
+ return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer->buft->context)->sess;
319
+ }
320
+
321
+ static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) {
322
+ auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
323
+ delete ctx;
324
+ }
325
+
326
+ static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) {
327
+ auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
328
+ return ctx->base;
329
+ }
330
+
331
+ static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
332
+ auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context);
333
+ auto sess = ctx->sess;
334
+
335
+ HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(),
336
+ tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage,
337
+ (int) ctx->repack);
338
+
339
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
340
+ ; // nothing to do for the view
341
+ } else {
342
+ if (!ctx->mapped) {
343
+ ctx->mmap();
344
+ }
345
+ }
346
+ return GGML_STATUS_SUCCESS;
347
+ }
348
+
349
+ // ======== Q4x4x2 ====================
350
// A pair of sign-adjusted 4-bit values unpacked from one byte.
struct x2_q4 {
    int v[2];
};

// Split byte `v` into its two q4_0 quants (low nibble, high nibble),
// shifting each from the stored [0,15] range to the signed [-8,7] range.
static x2_q4 unpack_q4(uint8_t v) {
    const int lo = (int) (v & 0x0f) - 8;
    const int hi = (int) (v >> 4) - 8;
    return { lo, hi };
}
358
+
359
// Debug-dump one q4_0 block: the first four low-nibble quants, the last
// four high-nibble quants, and the fp16 scale. `i` only labels the line.
// No-op unless verbose logging is enabled (HEX_VERBOSE).
static void dump_block_q4_0(const block_q4_0 * b, int i) {
    HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0],
                unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1],
                unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1],
                GGML_FP16_TO_FP32(b->d));
}
365
+
366
// Debug-dump super-block `i` of a packed q4x4x2 row `v` of width `k`.
// Row layout (mirrors repack_row_q4x4x2): k/2 bytes of packed quants
// first, then the fp16 scales (8 per super-block). Prints two lines:
// low nibbles (+scales d[0..3]) and high nibbles (+scales d[4..7]).
static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) {
    static const int qk = QK_Q4_0x4x2;
    const int dblk_size = 8 * 2;  // 8x __fp16
    const int qblk_size = qk / 2; // int4
    const int qrow_size = k / 2;  // int4 (not padded)

    const uint8_t * v_q = v + 0;         // quants first
    const uint8_t * v_d = v + qrow_size; // then scales

    const uint8_t * q = v_q + i * qblk_size;
    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);

    HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
                unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0],
                unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0],
                unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0],
                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));

    HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
                i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1],
                unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1],
                unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1],
                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
}
390
+
391
+ static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) {
392
+ static const int qk = QK4_0;
393
+
394
+ for (unsigned int i = 0; i < qk / 2; ++i) {
395
+ const int x0 = (x->qs[i] & 0x0F);
396
+ const int x1 = (x->qs[i] >> 4);
397
+ qs[bi * qk + i + 0] = x0;
398
+ qs[bi * qk + i + qk / 2] = x1;
399
+ }
400
+ }
401
+
402
+ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi) {
403
+ static const int qk = QK4_0;
404
+
405
+ for (unsigned int i = 0; i < qk / 2; ++i) {
406
+ const uint8_t x0 = qs[bi * qk + i + 0];
407
+ const uint8_t x1 = qs[bi * qk + i + qk / 2];
408
+ x->qs[i] = x0 | (x1 << 4);
409
+ }
410
+ }
411
+
412
// Repack one row of k q4_0 elements (8 q4_0 blocks per super-block) into
// the q4x4x2 device layout: all packed quants first (k/2 bytes), then the
// fp16 scales (8 per super-block). Within a super-block, quants from
// blocks 0-3 go to low nibbles and blocks 4-7 to high nibbles of the same
// bytes. `x` must hold nb*8 blocks (caller pads -- see init_row_q4x4x2).
static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
    static const int qk = QK_Q4_0x4x2;
    const int nb = (k + qk - 1) / qk; // number of blocks (padded)

    const int dblk_size = 8 * 2;  // 8x __fp16
    const int qblk_size = qk / 2; // int4
    const int qrow_size = k / 2;  // int4 (not padded to blocks); int narrowing -- assumes k fits

    uint8_t * y_q = y + 0;         // quants first
    uint8_t * y_d = y + qrow_size; // then scales

    if (opt_verbose > 2) {
        for (int i = 0; i < nb; i++) {
            dump_block_q4_0(&x[i * 8 + 0], 0);
            dump_block_q4_0(&x[i * 8 + 1], 1);
            dump_block_q4_0(&x[i * 8 + 2], 2);
            dump_block_q4_0(&x[i * 8 + 3], 3);
            dump_block_q4_0(&x[i * 8 + 4], 4);
            dump_block_q4_0(&x[i * 8 + 5], 5);
            dump_block_q4_0(&x[i * 8 + 6], 6);
            dump_block_q4_0(&x[i * 8 + 7], 7);
        }
    }

    // Repack the quants: unpack 8 blocks to flat bytes, then fuse byte j
    // (blocks 0-3) and byte j+128 (blocks 4-7) into one packed byte.
    for (int i = 0; i < nb; i++) {
        uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
        unpack_q4_0_quants(qs, &x[i * 8 + 0], 0);
        unpack_q4_0_quants(qs, &x[i * 8 + 1], 1);
        unpack_q4_0_quants(qs, &x[i * 8 + 2], 2);
        unpack_q4_0_quants(qs, &x[i * 8 + 3], 3);
        unpack_q4_0_quants(qs, &x[i * 8 + 4], 4);
        unpack_q4_0_quants(qs, &x[i * 8 + 5], 5);
        unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);
        unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);

        uint8_t * q = y_q + (i * qblk_size);
        for (int j = 0; j < qk / 2; j++) {
            q[j] = (qs[j + 128] << 4) | qs[j]; // 128 == qk/2
        }
    }

    // Repack the scales
    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
    // the last block is truncated and overridden by the scales.
    for (int i = 0; i < nb; i++) {
        // Repack the scales
        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
        d[0] = x[i * 8 + 0].d;
        d[1] = x[i * 8 + 1].d;
        d[2] = x[i * 8 + 2].d;
        d[3] = x[i * 8 + 3].d;
        d[4] = x[i * 8 + 4].d;
        d[5] = x[i * 8 + 5].d;
        d[6] = x[i * 8 + 6].d;
        d[7] = x[i * 8 + 7].d;
    }

    if (opt_verbose > 1) {
        for (int i = 0; i < nb; i++) {
            dump_packed_block_q4x4x2(y, i, k);
        }
    }
}
476
+
477
// Inverse of repack_row_q4x4x2: convert one q4x4x2-packed row `y` of width
// k back into nb*8 plain q4_0 blocks in `x` (quants first, then scales).
static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
    static const int qk = QK_Q4_0x4x2;
    const int nb = (k + qk - 1) / qk; // number of blocks (padded)

    const int dblk_size = 8 * 2;  // 8x __fp16
    const int qblk_size = qk / 2; // int4
    const int qrow_size = k / 2;  // int4 (not padded to blocks)

    const uint8_t * y_q = y + 0;         // quants first
    const uint8_t * y_d = y + qrow_size; // then scales

    if (opt_verbose > 1) {
        for (int i = 0; i < nb; i++) {
            dump_packed_block_q4x4x2(y, i, k);
        }
    }

    // Unpack the quants: split each packed byte into its low nibble
    // (flat index j, blocks 0-3) and high nibble (j+128, blocks 4-7).
    for (int i = 0; i < nb; i++) {
        uint8_t qs[QK_Q4_0x4x2]; // unpacked quants

        const uint8_t * q = y_q + (i * qblk_size);
        for (int j = 0; j < qk / 2; j++) {
            qs[j]       = q[j] & 0xf;
            qs[j + 128] = q[j] >> 4; // 128 == qk/2
        }

        pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
        pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
        pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
        pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
        pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
        pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
        pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
        pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
    }

    // Repack the scales
    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
    // the last block is truncated and overridden by the scales.
    for (int i = 0; i < nb; i++) {
        // Unpack the scales
        const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
        x[i * 8 + 0].d = d[0];
        x[i * 8 + 1].d = d[1];
        x[i * 8 + 2].d = d[2];
        x[i * 8 + 3].d = d[3];
        x[i * 8 + 4].d = d[4];
        x[i * 8 + 5].d = d[5];
        x[i * 8 + 6].d = d[6];
        x[i * 8 + 7].d = d[7];
    }

    if (opt_verbose > 2) {
        for (int i = 0; i < nb; i++) {
            dump_block_q4_0(&x[i * 8 + 0], 0);
            dump_block_q4_0(&x[i * 8 + 1], 1);
            dump_block_q4_0(&x[i * 8 + 2], 2);
            dump_block_q4_0(&x[i * 8 + 3], 3);
            dump_block_q4_0(&x[i * 8 + 4], 4);
            dump_block_q4_0(&x[i * 8 + 5], 5);
            dump_block_q4_0(&x[i * 8 + 6], 6);
            dump_block_q4_0(&x[i * 8 + 7], 7);
        }
    }
}
543
+
544
+ static void init_row_q4x4x2(block_q4_0 * x, int64_t k) {
545
+ static const int qk = QK_Q4_0x4x2;
546
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
547
+
548
+ // Init the quants such that they unpack into zeros
549
+ uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
550
+ memset(qs, 8, sizeof(qs));
551
+
552
+ for (int i = 0; i < nb; i++) {
553
+ pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
554
+ pack_q4_0_quants(&x[i * 8 + 1], qs, 1);
555
+ pack_q4_0_quants(&x[i * 8 + 2], qs, 2);
556
+ pack_q4_0_quants(&x[i * 8 + 3], qs, 3);
557
+ pack_q4_0_quants(&x[i * 8 + 4], qs, 4);
558
+ pack_q4_0_quants(&x[i * 8 + 5], qs, 5);
559
+ pack_q4_0_quants(&x[i * 8 + 6], qs, 6);
560
+ pack_q4_0_quants(&x[i * 8 + 7], qs, 7);
561
+ }
562
+
563
+ // Init the scales
564
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
565
+ // the last block is truncated and overriden by the scales.
566
+ for (int i = 0; i < nb; i++) {
567
+ // Unpack the scales
568
+ x[i * 8 + 0].d = 0;
569
+ x[i * 8 + 1].d = 0;
570
+ x[i * 8 + 2].d = 0;
571
+ x[i * 8 + 3].d = 0;
572
+ x[i * 8 + 4].d = 0;
573
+ x[i * 8 + 5].d = 0;
574
+ x[i * 8 + 6].d = 0;
575
+ x[i * 8 + 7].d = 0;
576
+ }
577
+ }
578
+
579
+ // repack q4_0 data into q4x4x2 tensor
580
+ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) {
581
+ int64_t nrows = ggml_nrows(t);
582
+
583
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
584
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
585
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
586
+
587
+ // Ensure we don't try to read more data than is available in the source buffer 'data'
588
+ // or write more than the tensor can hold.
589
+ const size_t total_tensor_size = (size_t)nrows * row_size;
590
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
591
+
592
+ // Calculate how many full rows and how many remaining bytes we need to process.
593
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
594
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
595
+
596
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
597
+ GGML_ASSERT(buf_pd != NULL);
598
+
599
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
600
+ GGML_ASSERT(buf_rp != NULL);
601
+
602
+ HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
603
+ t->ne[0], nrows, row_size);
604
+
605
+ init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros
606
+
607
+ // 1. Process all the full rows
608
+ for (int64_t i = 0; i < n_full_rows; i++) {
609
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
610
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
611
+
612
+ memcpy(buf_pd, src, row_size);
613
+ repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
614
+ memcpy(dst, buf_rp, row_size);
615
+ }
616
+
617
+ // 2. Process the final, potentially partial, row
618
+ if (n_rem_bytes > 0) {
619
+ const int64_t i = n_full_rows;
620
+ const uint8_t * src = (const uint8_t *) data + (i * row_size);
621
+ uint8_t * dst = (uint8_t *) t->data + (i * row_size);
622
+
623
+ // re-init the row because we are potentially copying a partial row
624
+ init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]);
625
+
626
+ // Copy only the remaining bytes from the source.
627
+ memcpy(buf_pd, src, n_rem_bytes);
628
+
629
+ // Repack the entire buffer
630
+ repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]);
631
+
632
+ // Write only the corresponding remaining bytes to the destination tensor.
633
+ memcpy(dst, buf_rp, n_rem_bytes);
634
+ }
635
+
636
+ ggml_aligned_free(buf_pd, row_size_pd);
637
+ ggml_aligned_free(buf_rp, row_size_rp);
638
+ }
639
+
640
+ // repack q4x4x2 tensor into q4_0 data
641
+ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) {
642
+ int64_t nrows = ggml_nrows(t);
643
+
644
+ size_t row_size = ggml_row_size(t->type, t->ne[0]);
645
+ size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad
646
+ size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)
647
+
648
+ // Ensure we don't try to copy more data than the tensor actually contains.
649
+ const size_t total_tensor_size = (size_t)nrows * row_size;
650
+ const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;
651
+
652
+ // Calculate how many full rows and how many remaining bytes we need to process.
653
+ const int64_t n_full_rows = n_bytes_to_copy / row_size;
654
+ const size_t n_rem_bytes = n_bytes_to_copy % row_size;
655
+
656
+ void * buf_pd = ggml_aligned_malloc(row_size_pd);
657
+ GGML_ASSERT(buf_pd != NULL);
658
+
659
+ void * buf_rp = ggml_aligned_malloc(row_size_rp);
660
+ GGML_ASSERT(buf_rp != NULL);
661
+
662
+ HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
663
+ t->ne[0], nrows, row_size);
664
+
665
+ memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros
666
+
667
+ // 1. Process all the full rows
668
+ for (int64_t i = 0; i < n_full_rows; i++) {
669
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
670
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
671
+
672
+ memcpy(buf_pd, src, row_size);
673
+ unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
674
+ memcpy(dst, buf_rp, row_size);
675
+ }
676
+
677
+ // 2. Process the final, potentially partial, row
678
+ if (n_rem_bytes > 0) {
679
+ const int64_t i = n_full_rows;
680
+ const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
681
+ uint8_t * dst = (uint8_t *) data + (i * row_size);
682
+
683
+ // We still need to read and unpack the entire source row because quantization is block-based.
684
+ memcpy(buf_pd, src, row_size);
685
+ unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
686
+
687
+ // But we only copy the remaining number of bytes to the destination.
688
+ memcpy(dst, buf_rp, n_rem_bytes);
689
+ }
690
+
691
+ ggml_aligned_free(buf_pd, row_size_pd);
692
+ ggml_aligned_free(buf_rp, row_size_rp);
693
+ }
694
+
695
+ // ======== Q8x4x2 ====================
696
// Debug helper: print the first and last few quants of q8_0 block 'b'
// (labelled with index 'i') together with its fp16 scale.
static void dump_block_q8_0(const block_q8_0 * b, int i) {
    HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2],
                b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d));
}
700
+
701
// Debug helper: print a slice of a packed q8x4x2 row 'v'.
// The packed row stores all k quant bytes first, then 8 fp16 scales per packed
// block; 'i' is the packed-block index and 'k' the (unpadded) row length.
static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) {
    static const int qk = QK_Q8_0x4x2;
    const int dblk_size = 8 * 2; // 8x __fp16
    const int qblk_size = qk;    // int8
    const int qrow_size = k;     // int8 (not padded)

    const uint8_t * v_q = v + 0;         // quants first
    const uint8_t * v_d = v + qrow_size; // then scales

    const uint8_t * q = v_q + i * qblk_size;
    const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size);

    HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
                q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127],
                GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3]));

    HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
                i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255],
                GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7]));
}
721
+
722
+ static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) {
723
+ static const int qk = QK8_0;
724
+
725
+ for (unsigned int i = 0; i < qk; ++i) {
726
+ qs[bi * qk + i] = x->qs[i];
727
+ }
728
+ }
729
+
730
+ static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) {
731
+ static const int qk = QK8_0;
732
+
733
+ for (unsigned int i = 0; i < qk; ++i) {
734
+ x->qs[i] = qs[bi * qk + i];
735
+ }
736
+ }
737
+
738
// Repack one row of q8_0 blocks (x) into the flat q8x4x2 layout (y):
// all k quant bytes first, then the 8 fp16 scales of each packed block.
static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
    static const int qk = QK_Q8_0x4x2;
    const int nb = (k + qk - 1) / qk; // number of blocks (padded)

    const int dblk_size = 8 * 2; // 8x __fp16
    const int qblk_size = qk;    // int8
    const int qrow_size = k;     // int8 (not padded to blocks)

    uint8_t * y_q = y + 0;         // quants first
    uint8_t * y_d = y + qrow_size; // then scales

    if (opt_verbose > 2) {
        for (int i = 0; i < nb; i++) {
            dump_block_q8_0(&x[i * 8 + 0], 0);
            dump_block_q8_0(&x[i * 8 + 1], 1);
            dump_block_q8_0(&x[i * 8 + 2], 2);
            dump_block_q8_0(&x[i * 8 + 3], 3);
            dump_block_q8_0(&x[i * 8 + 4], 4);
            dump_block_q8_0(&x[i * 8 + 5], 5);
            dump_block_q8_0(&x[i * 8 + 6], 6);
            dump_block_q8_0(&x[i * 8 + 7], 7);
        }
    }

    // Repack the quants: gather 8 q8_0 blocks into one flat 256-byte chunk
    for (int i = 0; i < nb; i++) {
        uint8_t qs[QK_Q8_0x4x2]; // unpacked quants

        unpack_q8_0_quants(qs, &x[i * 8 + 0], 0);
        unpack_q8_0_quants(qs, &x[i * 8 + 1], 1);
        unpack_q8_0_quants(qs, &x[i * 8 + 2], 2);
        unpack_q8_0_quants(qs, &x[i * 8 + 3], 3);
        unpack_q8_0_quants(qs, &x[i * 8 + 4], 4);
        unpack_q8_0_quants(qs, &x[i * 8 + 5], 5);
        unpack_q8_0_quants(qs, &x[i * 8 + 6], 6);
        unpack_q8_0_quants(qs, &x[i * 8 + 7], 7);

        uint8_t * q = y_q + (i * qblk_size);
        for (int j = 0; j < qk; j++) {
            q[j] = qs[j];
        }
    }

    // Repack the scales
    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
    // the last block is truncated and overriden by the scales.
    for (int i = 0; i < nb; i++) {
        // Repack the scales
        ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
        d[0] = x[i * 8 + 0].d;
        d[1] = x[i * 8 + 1].d;
        d[2] = x[i * 8 + 2].d;
        d[3] = x[i * 8 + 3].d;
        d[4] = x[i * 8 + 4].d;
        d[5] = x[i * 8 + 5].d;
        d[6] = x[i * 8 + 6].d;
        d[7] = x[i * 8 + 7].d;
    }

    if (opt_verbose > 1) {
        for (int i = 0; i < nb; i++) {
            dump_packed_block_q8x4x2(y, i, k);
        }
    }
}
803
+
804
+ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
805
+ static const int qk = QK_Q8_0x4x2;
806
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
807
+
808
+ const int dblk_size = 8 * 2; // 8x __fp16
809
+ const int qblk_size = qk; // int8
810
+ const int qrow_size = k; // int8 (not padded to blocks)
811
+
812
+ const uint8_t * y_q = y + 0; // quants first
813
+ const uint8_t * y_d = y + qrow_size; // then scales
814
+
815
+ if (opt_verbose > 1) {
816
+ for (int i = 0; i < nb; i++) {
817
+ dump_packed_block_q8x4x2(y, i, k);
818
+ }
819
+ }
820
+
821
+ // Unpack the quants
822
+ for (int i = 0; i < nb; i++) {
823
+ uint8_t qs[QK_Q4_0x4x2]; // unpacked quants
824
+
825
+ const uint8_t * q = y_q + (i * qblk_size);
826
+ for (int j = 0; j < qk; j++) {
827
+ qs[j] = q[j];
828
+ }
829
+
830
+ pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
831
+ pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
832
+ pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
833
+ pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
834
+ pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
835
+ pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
836
+ pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
837
+ pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
838
+ }
839
+
840
+ // Repack the scales
841
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
842
+ // the last block is truncated and overriden by the scales.
843
+ for (int i = 0; i < nb; i++) {
844
+ // Unpack the scales
845
+ const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
846
+ x[i * 8 + 0].d = d[0];
847
+ x[i * 8 + 1].d = d[1];
848
+ x[i * 8 + 2].d = d[2];
849
+ x[i * 8 + 3].d = d[3];
850
+ x[i * 8 + 4].d = d[4];
851
+ x[i * 8 + 5].d = d[5];
852
+ x[i * 8 + 6].d = d[6];
853
+ x[i * 8 + 7].d = d[7];
854
+ }
855
+
856
+ if (opt_verbose > 2) {
857
+ for (int i = 0; i < nb; i++) {
858
+ dump_block_q8_0(&x[i * 8 + 0], 0);
859
+ dump_block_q8_0(&x[i * 8 + 1], 1);
860
+ dump_block_q8_0(&x[i * 8 + 2], 2);
861
+ dump_block_q8_0(&x[i * 8 + 3], 3);
862
+ dump_block_q8_0(&x[i * 8 + 4], 4);
863
+ dump_block_q8_0(&x[i * 8 + 5], 5);
864
+ dump_block_q8_0(&x[i * 8 + 6], 6);
865
+ dump_block_q8_0(&x[i * 8 + 7], 7);
866
+ }
867
+ }
868
+ }
869
+
870
+ static void init_row_q8x4x2(block_q8_0 * x, int64_t k) {
871
+ static const int qk = QK_Q8_0x4x2;
872
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
873
+
874
+ // Init the quants such that they unpack into zeros
875
+ uint8_t qs[QK_Q8_0x4x2]; // unpacked quants
876
+ memset(qs, 0, sizeof(qs));
877
+
878
+ for (int i = 0; i < nb; i++) {
879
+ pack_q8_0_quants(&x[i * 8 + 0], qs, 0);
880
+ pack_q8_0_quants(&x[i * 8 + 1], qs, 1);
881
+ pack_q8_0_quants(&x[i * 8 + 2], qs, 2);
882
+ pack_q8_0_quants(&x[i * 8 + 3], qs, 3);
883
+ pack_q8_0_quants(&x[i * 8 + 4], qs, 4);
884
+ pack_q8_0_quants(&x[i * 8 + 5], qs, 5);
885
+ pack_q8_0_quants(&x[i * 8 + 6], qs, 6);
886
+ pack_q8_0_quants(&x[i * 8 + 7], qs, 7);
887
+ }
888
+
889
+ // Init the scales
890
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
891
+ // the last block is truncated and overriden by the scales.
892
+ for (int i = 0; i < nb; i++) {
893
+ // Unpack the scales
894
+ x[i * 8 + 0].d = 0;
895
+ x[i * 8 + 1].d = 0;
896
+ x[i * 8 + 2].d = 0;
897
+ x[i * 8 + 3].d = 0;
898
+ x[i * 8 + 4].d = 0;
899
+ x[i * 8 + 5].d = 0;
900
+ x[i * 8 + 6].d = 0;
901
+ x[i * 8 + 7].d = 0;
902
+ }
903
+ }
904
+
905
// repack q8_0 data into q8x4x2 tensor
// Copies 'size' bytes of host q8_0 row data from 'data' into tensor 't',
// converting each row into the q8x4x2 device layout (quants first, then scales).
static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) {
    int64_t nrows = ggml_nrows(t);

    size_t row_size = ggml_row_size(t->type, t->ne[0]);
    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
    size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)

    // Ensure we don't try to read more data than is available in the source buffer 'data'
    // or write more than the tensor can hold.
    const size_t total_tensor_size = (size_t)nrows * row_size;
    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;

    // Calculate how many full rows and how many remaining bytes we need to process.
    const int64_t n_full_rows = n_bytes_to_copy / row_size;
    const size_t n_rem_bytes = n_bytes_to_copy % row_size;

    // buf_pd: padded staging copy of one source row; buf_rp: repacked output row.
    void * buf_pd = ggml_aligned_malloc(row_size_pd);
    GGML_ASSERT(buf_pd != NULL);

    void * buf_rp = ggml_aligned_malloc(row_size_rp);
    GGML_ASSERT(buf_rp != NULL);

    HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
                t->ne[0], nrows, row_size);

    init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros

    // 1. Process all the full rows
    for (int64_t i = 0; i < n_full_rows; i++) {
        const uint8_t * src = (const uint8_t *) data + (i * row_size);
        uint8_t * dst = (uint8_t *) t->data + (i * row_size);

        memcpy(buf_pd, src, row_size);
        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);
        memcpy(dst, buf_rp, row_size);
    }

    // 2. Process the final, potentially partial, row
    if (n_rem_bytes > 0) {
        const int64_t i = n_full_rows;
        const uint8_t * src = (const uint8_t *) data + (i * row_size);
        uint8_t * dst = (uint8_t *) t->data + (i * row_size);

        // re-init the row because we are potentially copying a partial row
        init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]);

        // Copy only the remaining bytes from the source.
        memcpy(buf_pd, src, n_rem_bytes);

        // Repack the entire buffer
        repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]);

        // Write only the corresponding remaining bytes to the destination tensor.
        memcpy(dst, buf_rp, n_rem_bytes);
    }

    ggml_aligned_free(buf_pd, row_size_pd);
    ggml_aligned_free(buf_rp, row_size_rp);
}
965
+
966
// repack q8x4x2 tensor into q8_0 data
// Converts rows of 't' from the q8x4x2 device layout back into plain q8_0 data,
// writing at most 'size' bytes into 'data'.
static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) {
    int64_t nrows = ggml_nrows(t);

    size_t row_size = ggml_row_size(t->type, t->ne[0]);
    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad
    size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)

    // Ensure we don't try to copy more data than the tensor actually contains.
    const size_t total_tensor_size = (size_t)nrows * row_size;
    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;

    // Calculate how many full rows and how many remaining bytes we need to process.
    const int64_t n_full_rows = n_bytes_to_copy / row_size;
    const size_t n_rem_bytes = n_bytes_to_copy % row_size;

    // buf_pd: staging copy of one packed row; buf_rp: unpacked q8_0 row.
    void * buf_pd = ggml_aligned_malloc(row_size_pd);
    GGML_ASSERT(buf_pd != NULL);

    void * buf_rp = ggml_aligned_malloc(row_size_rp);
    GGML_ASSERT(buf_rp != NULL);

    HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size,
                t->ne[0], nrows, row_size);

    memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros

    // 1. Process all the full rows
    for (int64_t i = 0; i < n_full_rows; i++) {
        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
        uint8_t * dst = (uint8_t *) data + (i * row_size);

        memcpy(buf_pd, src, row_size);
        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
        memcpy(dst, buf_rp, row_size);
    }

    // 2. Process the final, potentially partial, row
    if (n_rem_bytes > 0) {
        const int64_t i = n_full_rows;
        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
        uint8_t * dst = (uint8_t *) data + (i * row_size);

        // We still need to read and unpack the entire source row because quantization is block-based.
        memcpy(buf_pd, src, row_size);
        unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);

        // But we only copy the remaining number of bytes to the destination.
        memcpy(dst, buf_rp, n_rem_bytes);
    }

    ggml_aligned_free(buf_pd, row_size_pd);
    ggml_aligned_free(buf_rp, row_size_rp);
}
1020
+
1021
+ // ======== MXFP4x4x2 ====================
1022
// Pair of mxfp4 lookup values decoded from one packed byte
// (low nibble -> v[0], high nibble -> v[1]); see unpack_mxfp4().
struct x2_mxfp4 {
    int v[2];
};
1025
+
1026
+ static x2_mxfp4 unpack_mxfp4(uint8_t v) {
1027
+ x2_mxfp4 x;
1028
+ x.v[0] = kvalues_mxfp4[(v & 0x0f)];
1029
+ x.v[1] = kvalues_mxfp4[(v >> 4)];
1030
+ return x;
1031
+ }
1032
+
1033
// Debug helper: print a few decoded quants of mxfp4 block 'b'
// (labelled with index 'i') together with its E8M0 scale.
static void dump_block_mxfp4(const block_mxfp4 * b, int i) {
    HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0],
                unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0],
                unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1],
                unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e));
}
1039
+
1040
// Debug helper: print a slice of a packed mxfp4x4x2 row 'v'.
// The packed row stores all k/2 quant bytes (two 4-bit values each) first,
// then 8 E8M0 scales per packed block; 'i' is the packed-block index.
static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) {
    static const int qk = QK_MXFP4x4x2;
    const int eblk_size = 8 * 1;  // 8x E8M0
    const int qblk_size = qk / 2; // int4
    const int qrow_size = k / 2;  // int4 (not padded)

    const uint8_t * v_q = v + 0;         // quants first
    const uint8_t * v_e = v + qrow_size; // then scales

    const uint8_t * q = v_q + i * qblk_size;
    const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size);

    HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i,
                unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0],
                unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0],
                unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0],
                unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]),
                GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3]));

    HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n",
                i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1],
                unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1],
                unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1],
                unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]),
                GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7]));
}
1066
+
1067
+ static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) {
1068
+ static const int qk = QK_MXFP4;
1069
+
1070
+ for (unsigned int i = 0; i < qk / 2; ++i) {
1071
+ const uint8_t x0 = (x->qs[i] & 0x0F);
1072
+ const uint8_t x1 = (x->qs[i] >> 4);
1073
+ qs[bi * qk + i + 0] = x0;
1074
+ qs[bi * qk + i + qk / 2] = x1;
1075
+ }
1076
+ }
1077
+
1078
+ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) {
1079
+ static const int qk = QK4_0;
1080
+
1081
+ for (unsigned int i = 0; i < qk / 2; ++i) {
1082
+ const uint8_t x0 = qs[bi * qk + i + 0];
1083
+ const uint8_t x1 = qs[bi * qk + i + qk / 2];
1084
+ x->qs[i] = x0 | (x1 << 4);
1085
+ }
1086
+ }
1087
+
1088
// Repack one row of mxfp4 blocks (x) into the flat mxfp4x4x2 layout (y):
// all k/2 quant bytes first (two 4-bit values per byte), then the 8 E8M0
// scales of each packed block.
static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {
    static const int qk = QK_MXFP4x4x2;
    const int nb = (k + qk - 1) / qk; // number of blocks (padded)

    const int eblk_size = 8 * 1;  // 8x E8M0
    const int qblk_size = qk / 2; // int4
    const int qrow_size = k / 2;  // int4 (not padded to blocks)

    uint8_t * y_q = y + 0;         // quants first
    uint8_t * y_e = y + qrow_size; // then scales

    if (opt_verbose > 2) {
        for (int i = 0; i < nb; i++) {
            dump_block_mxfp4(&x[i * 8 + 0], 0);
            dump_block_mxfp4(&x[i * 8 + 1], 1);
            dump_block_mxfp4(&x[i * 8 + 2], 2);
            dump_block_mxfp4(&x[i * 8 + 3], 3);
            dump_block_mxfp4(&x[i * 8 + 4], 4);
            dump_block_mxfp4(&x[i * 8 + 5], 5);
            dump_block_mxfp4(&x[i * 8 + 6], 6);
            dump_block_mxfp4(&x[i * 8 + 7], 7);
        }
    }

    // Repack the quants
    for (int i = 0; i < nb; i++) {
        uint8_t qs[QK_MXFP4x4x2]; // unpacked quants

        unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0);
        unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1);
        unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2);
        unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3);
        unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4);
        unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5);
        unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);
        unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);

        // pair each low-half nibble (qs[j]) with the matching high-half
        // nibble (qs[j + 128], i.e. qs[j + qk/2]) into one byte
        uint8_t * q = y_q + (i * qblk_size);
        for (int j = 0; j < qk / 2; j++) {
            q[j] = (qs[j + 128] << 4) | qs[j];
        }
    }

    // Repack the scales
    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
    // the last block is truncated and overriden by the scales.
    for (int i = 0; i < nb; i++) {
        // Repack the scales
        uint8_t * e = (uint8_t *) (y_e + i * eblk_size);
        e[0] = x[i * 8 + 0].e;
        e[1] = x[i * 8 + 1].e;
        e[2] = x[i * 8 + 2].e;
        e[3] = x[i * 8 + 3].e;
        e[4] = x[i * 8 + 4].e;
        e[5] = x[i * 8 + 5].e;
        e[6] = x[i * 8 + 6].e;
        e[7] = x[i * 8 + 7].e;
    }

    if (opt_verbose > 1) {
        for (int i = 0; i < nb; i++) {
            dump_packed_block_mxfp4x4x2(y, i, k);
        }
    }
}
1153
+
1154
// Unpack one flat mxfp4x4x2 row (y) back into mxfp4 blocks (x).
// The packed row stores all k/2 quant bytes first, then 8 E8M0 scales per block.
static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {
    static const int qk = QK_MXFP4x4x2;
    const int nb = (k + qk - 1) / qk; // number of blocks (padded)

    const int eblk_size = 8 * 1;  // 8x E8M0
    const int qblk_size = qk / 2; // int4
    const int qrow_size = k / 2;  // int4 (not padded to blocks)

    const uint8_t * y_q = y + 0;         // quants first
    const uint8_t * y_e = y + qrow_size; // then scales

    if (opt_verbose > 1) {
        for (int i = 0; i < nb; i++) {
            dump_packed_block_mxfp4x4x2(y, i, k);
        }
    }

    // Unpack the quants
    for (int i = 0; i < nb; i++) {
        uint8_t qs[QK_MXFP4x4x2]; // unpacked quants

        // split each packed byte: low nibble into the first half of qs,
        // high nibble into the second half (offset 128 == qk/2)
        const uint8_t * q = y_q + (i * qblk_size);
        for (int j = 0; j < qk / 2; j++) {
            qs[j] = q[j] & 0xf;
            qs[j + 128] = q[j] >> 4;
        }

        pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
        pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
        pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
        pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
        pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
        pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
        pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
        pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
    }

    // Unpack the scales
    // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
    // the last block is truncated and overriden by the scales.
    for (int i = 0; i < nb; i++) {
        const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size);
        x[i * 8 + 0].e = e[0];
        x[i * 8 + 1].e = e[1];
        x[i * 8 + 2].e = e[2];
        x[i * 8 + 3].e = e[3];
        x[i * 8 + 4].e = e[4];
        x[i * 8 + 5].e = e[5];
        x[i * 8 + 6].e = e[6];
        x[i * 8 + 7].e = e[7];
    }

    if (opt_verbose > 2) {
        for (int i = 0; i < nb; i++) {
            dump_block_mxfp4(&x[i * 8 + 0], 0);
            dump_block_mxfp4(&x[i * 8 + 1], 1);
            dump_block_mxfp4(&x[i * 8 + 2], 2);
            dump_block_mxfp4(&x[i * 8 + 3], 3);
            dump_block_mxfp4(&x[i * 8 + 4], 4);
            dump_block_mxfp4(&x[i * 8 + 5], 5);
            dump_block_mxfp4(&x[i * 8 + 6], 6);
            dump_block_mxfp4(&x[i * 8 + 7], 7);
        }
    }
}
1220
+
1221
+ static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) {
1222
+ static const int qk = QK_MXFP4x4x2;
1223
+ const int nb = (k + qk - 1) / qk; // number of blocks (padded)
1224
+
1225
+ // Init the quants such that they unpack into zeros
1226
+ uint8_t qs[QK_MXFP4x4x2]; // unpacked quants
1227
+ memset(qs, 0, sizeof(qs));
1228
+
1229
+ for (int i = 0; i < nb; i++) {
1230
+ pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
1231
+ pack_mxfp4_quants(&x[i * 8 + 1], qs, 1);
1232
+ pack_mxfp4_quants(&x[i * 8 + 2], qs, 2);
1233
+ pack_mxfp4_quants(&x[i * 8 + 3], qs, 3);
1234
+ pack_mxfp4_quants(&x[i * 8 + 4], qs, 4);
1235
+ pack_mxfp4_quants(&x[i * 8 + 5], qs, 5);
1236
+ pack_mxfp4_quants(&x[i * 8 + 6], qs, 6);
1237
+ pack_mxfp4_quants(&x[i * 8 + 7], qs, 7);
1238
+ }
1239
+
1240
+ // Init the scales
1241
+ // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
1242
+ // the last block is truncated and overriden by the scales.
1243
+ for (int i = 0; i < nb; i++) {
1244
+ // Unpack the scales
1245
+ x[i * 8 + 0].e = 0;
1246
+ x[i * 8 + 1].e = 0;
1247
+ x[i * 8 + 2].e = 0;
1248
+ x[i * 8 + 3].e = 0;
1249
+ x[i * 8 + 4].e = 0;
1250
+ x[i * 8 + 5].e = 0;
1251
+ x[i * 8 + 6].e = 0;
1252
+ x[i * 8 + 7].e = 0;
1253
+ }
1254
+ }
1255
+
1256
// repack mxfp4 data into mxfp4x4x2 tensor
// Copies 'size' bytes of host mxfp4 row data from 'data' into tensor 't',
// converting each row into the mxfp4x4x2 device layout (quants first, then scales).
static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) {
    int64_t nrows = ggml_nrows(t);

    size_t row_size = ggml_row_size(t->type, t->ne[0]);
    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
    size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)

    // Ensure we don't try to read more data than is available in the source buffer 'data'
    // or write more than the tensor can hold.
    const size_t total_tensor_size = (size_t)nrows * row_size;
    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;

    // Calculate how many full rows and how many remaining bytes we need to process.
    const int64_t n_full_rows = n_bytes_to_copy / row_size;
    const size_t n_rem_bytes = n_bytes_to_copy % row_size;

    // buf_pd: padded staging copy of one source row; buf_rp: repacked output row.
    void * buf_pd = ggml_aligned_malloc(row_size_pd);
    GGML_ASSERT(buf_pd != NULL);

    void * buf_rp = ggml_aligned_malloc(row_size_rp);
    GGML_ASSERT(buf_rp != NULL);

    HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
                size, t->ne[0], nrows, row_size);

    init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros

    // 1. Process all the full rows
    for (int64_t i = 0; i < n_full_rows; i++) {
        const uint8_t * src = (const uint8_t *) data + (i * row_size);
        uint8_t * dst = (uint8_t *) t->data + (i * row_size);

        memcpy(buf_pd, src, row_size);
        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);
        memcpy(dst, buf_rp, row_size);
    }

    // 2. Process the final, potentially partial, row
    if (n_rem_bytes > 0) {
        const int64_t i = n_full_rows;
        const uint8_t * src = (const uint8_t *) data + (i * row_size);
        uint8_t * dst = (uint8_t *) t->data + (i * row_size);

        // re-init the row because we are potentially copying a partial row
        init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]);

        // Copy only the remaining bytes from the source.
        memcpy(buf_pd, src, n_rem_bytes);

        // Repack the entire buffer (partial data + zero padding).
        repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]);

        // Write only the corresponding remaining bytes to the destination tensor.
        memcpy(dst, buf_rp, n_rem_bytes);
    }

    ggml_aligned_free(buf_pd, row_size_pd);
    ggml_aligned_free(buf_rp, row_size_rp);
}
1316
+
1317
// repack mxfp4x4x2 tensor into mxfp4 data
// Converts rows of 't' from the mxfp4x4x2 device layout back into plain mxfp4
// data, writing at most 'size' bytes into 'data'.
static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) {
    int64_t nrows = ggml_nrows(t);

    size_t row_size = ggml_row_size(t->type, t->ne[0]);
    size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad
    size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any)

    // Ensure we don't try to copy more data than the tensor actually contains.
    const size_t total_tensor_size = (size_t)nrows * row_size;
    const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size;

    // Calculate how many full rows and how many remaining bytes we need to process.
    const int64_t n_full_rows = n_bytes_to_copy / row_size;
    const size_t n_rem_bytes = n_bytes_to_copy % row_size;

    // buf_pd: staging copy of one packed row; buf_rp: unpacked mxfp4 row.
    void * buf_pd = ggml_aligned_malloc(row_size_pd);
    GGML_ASSERT(buf_pd != NULL);

    void * buf_rp = ggml_aligned_malloc(row_size_rp);
    GGML_ASSERT(buf_rp != NULL);

    HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data,
                size, t->ne[0], nrows, row_size);

    memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros

    // 1. Process all the full rows
    for (int64_t i = 0; i < n_full_rows; i++) {
        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
        uint8_t * dst = (uint8_t *) data + (i * row_size);

        memcpy(buf_pd, src, row_size);
        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);
        memcpy(dst, buf_rp, row_size);
    }

    // 2. Process the final, potentially partial, row
    if (n_rem_bytes > 0) {
        const int64_t i = n_full_rows;
        const uint8_t * src = (const uint8_t *) t->data + (i * row_size);
        uint8_t * dst = (uint8_t *) data + (i * row_size);

        // We still need to read and unpack the entire source row because the format is block-based.
        memcpy(buf_pd, src, row_size);
        unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]);

        // But we only copy the remaining number of bytes to the destination to respect the size limit.
        memcpy(dst, buf_rp, n_rem_bytes);
    }

    ggml_aligned_free(buf_pd, row_size_pd);
    ggml_aligned_free(buf_rp, row_size_rp);
}
1371
+
1372
// Upload host data into a hexagon-backed tensor. Quantized types (q4_0, q8_0,
// mxfp4) are repacked into their "...x4x2" device layouts on the way in;
// all other types are stored with a plain memcpy.
static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                   ggml_tensor * tensor,
                                                   const void * data,
                                                   size_t offset,
                                                   size_t size) {
    auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
    auto sess = ctx->sess;

    HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
                offset, size);

    switch (tensor->type) {
        case GGML_TYPE_Q4_0:
            // repacked types only support uploads that start at offset 0
            GGML_ASSERT(offset == 0);
            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
            repack_q4_0_q4x4x2(tensor, data, size);
            break;

        case GGML_TYPE_Q8_0:
            GGML_ASSERT(offset == 0);
            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
            repack_q8_0_q8x4x2(tensor, data, size);
            break;

        case GGML_TYPE_MXFP4:
            GGML_ASSERT(offset == 0);
            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
            repack_mxfp4_mxfp4x4x2(tensor, data, size);
            break;

        default:
            // non-repacked types: raw byte copy at the requested offset
            memcpy((char *) tensor->data + offset, data, size);
            break;
    }
}
1407
+
1408
// Download tensor data from a hexagon-backed buffer into host memory.
// Quantized types (q4_0, q8_0, mxfp4) are converted back from their
// "...x4x2" device layouts; all other types are read with a plain memcpy.
static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
                                                   const ggml_tensor * tensor,
                                                   void * data,
                                                   size_t offset,
                                                   size_t size) {
    auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
    auto sess = ctx->sess;

    HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data,
                offset, size);

    switch (tensor->type) {
        case GGML_TYPE_Q4_0:
            // repacked types only support downloads that start at offset 0
            GGML_ASSERT(offset == 0);
            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
            repack_q4x4x2_q4_0(data, tensor, size);
            break;

        case GGML_TYPE_Q8_0:
            GGML_ASSERT(offset == 0);
            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
            repack_q8x4x2_q8_0(data, tensor, size);
            break;

        case GGML_TYPE_MXFP4:
            GGML_ASSERT(offset == 0);
            GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
            repack_mxfp4x4x2_mxfp4(data, tensor, size);
            break;

        default:
            // non-repacked types: raw byte copy from the requested offset
            memcpy(data, (const char *) tensor->data + offset, size);
            break;
    }
}
1443
+
1444
// Direct tensor-to-tensor copy within/between hexagon buffers is not
// implemented. Returning false tells the ggml backend layer to fall back to
// the generic path (get_tensor from src + set_tensor into dst).
static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
                                                   const struct ggml_tensor * src,
                                                   struct ggml_tensor * dst) {
    GGML_UNUSED(buffer);
    GGML_UNUSED(src);
    GGML_UNUSED(dst);
    // we might optimize this later, for now take the slow path (ie get/set_tensor)
    return false;
}
1453
+
1454
+ static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1455
+ auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context;
1456
+ auto sess = ctx->sess;
1457
+ HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size);
1458
+ memset(ctx->base, value, ctx->size);
1459
+ }
1460
+
1461
// Buffer vtable shared by the plain and repack hexagon buffer types.
// memset_tensor and reset are intentionally unimplemented (NULL);
// cpy_tensor always returns false, forcing the generic get/set path.
static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_hexagon_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_hexagon_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_hexagon_buffer_init_tensor,
    /* .memset_tensor   = */ NULL,
    /* .set_tensor      = */ ggml_backend_hexagon_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_hexagon_buffer_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_hexagon_buffer_cpy_tensor,
    /* .clear           = */ ggml_backend_hexagon_buffer_clear,
    /* .reset           = */ NULL,
};
1472
+
1473
+ // ** backend buffer type
1474
+
1475
+ static const char * ggml_backend_hexagon_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
1476
+ return static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->name.c_str();
1477
+ }
1478
+
1479
+ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer(
1480
+ ggml_backend_buffer_type_t buffer_type, size_t size) {
1481
+ auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
1482
+ try {
1483
+ ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/);
1484
+ return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
1485
+ } catch (const std::exception & exc) {
1486
+ GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
1487
+ return nullptr;
1488
+ }
1489
+ }
1490
+
1491
+ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffer(
1492
+ ggml_backend_buffer_type_t buffer_type, size_t size) {
1493
+ auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess;
1494
+ try {
1495
+ ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/);
1496
+ return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size);
1497
+ } catch (const std::exception & exc) {
1498
+ GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what());
1499
+ return nullptr;
1500
+ }
1501
+ }
1502
+
1503
+ static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
1504
+ return 128; // HVX alignment
1505
+ GGML_UNUSED(buffer_type);
1506
+ }
1507
+
1508
+ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) {
1509
+ return ggml_nbytes(t);
1510
+ }
1511
+
1512
+ static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
1513
+ return 1 * 1024 * 1024 * 1024; // 1GB per buffer
1514
+ GGML_UNUSED(buffer_type);
1515
+ }
1516
+
1517
// Whether plain hexagon buffers are usable as host memory; controlled by the
// opt_hostbuf runtime option.
static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    return opt_hostbuf;
    GGML_UNUSED(buft);
}

// Repacked buffers hold a DSP-specific layout and are never host buffers.
static bool ggml_backend_hexagon_repack_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    return false;
    GGML_UNUSED(buft);
}
1526
+
1527
// Buffer-type vtable for plain hexagon buffers.
static ggml_backend_buffer_type_i ggml_backend_hexagon_buffer_type_interface = {
    /* .get_name       = */ ggml_backend_hexagon_buffer_type_name,
    /* .alloc_buffer   = */ ggml_backend_hexagon_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_hexagon_buffer_type_get_alignment,
    /* .get_max_size   = */ ggml_backend_hexagon_buffer_type_get_max_size,
    /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
    /* .is_host        = */ ggml_backend_hexagon_buffer_type_is_host,
};

// Buffer-type vtable for repacked (weight) buffers. Differs from the plain
// type only in alloc_buffer and is_host; the distinct alloc_buffer pointer is
// also how ggml_backend_buffer_is_hexagon_repack() identifies these buffers.
static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interface = {
    /* .get_name       = */ ggml_backend_hexagon_buffer_type_name,
    /* .alloc_buffer   = */ ggml_backend_hexagon_repack_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_hexagon_buffer_type_get_alignment,
    /* .get_max_size   = */ ggml_backend_hexagon_buffer_type_get_max_size,
    /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size,
    /* .is_host        = */ ggml_backend_hexagon_repack_buffer_type_is_host,
};
1544
+
1545
+ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) {
1546
+ this->valid_session = false;
1547
+ this->valid_handle = false;
1548
+ this->valid_queue = false;
1549
+ this->valid_iface = false;
1550
+
1551
+ this->domain_id = 3; // Default for CDSP, updated after the session is created
1552
+ this->session_id = 0; // Default for CDSP, updated after the session is created
1553
+ this->dev_id = dev_id;
1554
+ this->name = std::string("HTP") + std::to_string(dev_id);
1555
+
1556
+ this->op_pending = 0;
1557
+ this->prof_usecs = 0;
1558
+ this->prof_cycles = 0;
1559
+ this->prof_pkts = 0;
1560
+
1561
+ GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str());
1562
+
1563
+ domain * my_domain = get_domain(this->domain_id);
1564
+ if (my_domain == NULL) {
1565
+ GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n");
1566
+ throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)");
1567
+ }
1568
+
1569
+ // Create new session
1570
+ if (dev_id != 0) {
1571
+ struct remote_rpc_reserve_new_session n;
1572
+ n.domain_name_len = strlen(CDSP_DOMAIN_NAME);
1573
+ n.domain_name = const_cast<char *>(CDSP_DOMAIN_NAME);
1574
+ n.session_name = const_cast<char *>(this->name.c_str());
1575
+ n.session_name_len = this->name.size();
1576
+
1577
+ int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n));
1578
+ if (err != AEE_SUCCESS) {
1579
+ GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err);
1580
+ throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)");
1581
+ }
1582
+
1583
+ // Save the IDs
1584
+ this->session_id = n.session_id;
1585
+ this->domain_id = n.effective_domain_id;
1586
+ this->valid_session = true;
1587
+ }
1588
+
1589
+ // Get session URI
1590
+
1591
+ char session_uri[256];
1592
+ {
1593
+ char htp_uri[256];
1594
+ snprintf(htp_uri, sizeof(htp_uri), "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch);
1595
+
1596
+ struct remote_rpc_get_uri u = {};
1597
+ u.session_id = this->session_id;
1598
+ u.domain_name = const_cast<char *>(CDSP_DOMAIN_NAME);
1599
+ u.domain_name_len = strlen(CDSP_DOMAIN_NAME);
1600
+ u.module_uri = const_cast<char *>(htp_uri);
1601
+ u.module_uri_len = strlen(htp_uri);
1602
+ u.uri = session_uri;
1603
+ u.uri_len = sizeof(session_uri);
1604
+
1605
+ int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u));
1606
+ if (err != AEE_SUCCESS) {
1607
+ // fallback to single session uris
1608
+ int htp_URI_domain_len = strlen(htp_uri) + MAX_DOMAIN_NAMELEN;
1609
+
1610
+ snprintf(session_uri, htp_URI_domain_len, "%s%s", htp_uri, my_domain->uri);
1611
+
1612
+ GGML_LOG_WARN("ggml-hex: failed to get URI for session %d : error 0x%x. Falling back to single session URI: %s\n", dev_id, err, session_uri);
1613
+ }
1614
+ }
1615
+
1616
+ // Enable Unsigned PD
1617
+ {
1618
+ struct remote_rpc_control_unsigned_module u;
1619
+ u.domain = this->domain_id;
1620
+ u.enable = 1;
1621
+ int err = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u));
1622
+ if (err != AEE_SUCCESS) {
1623
+ GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err);
1624
+ throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)");
1625
+ }
1626
+ }
1627
+
1628
+ // Open session
1629
+ int err = htp_iface_open(session_uri, &this->handle);
1630
+ if (err != AEE_SUCCESS) {
1631
+ GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err);
1632
+ throw std::runtime_error("ggml-hex: failed to open session (see log for details)");
1633
+ }
1634
+
1635
+ this->valid_handle = true;
1636
+
1637
+ GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(),
1638
+ this->session_id, this->domain_id, session_uri, (unsigned long) this->handle);
1639
+
1640
+ // Enable FastRPC QoS mode
1641
+ {
1642
+ struct remote_rpc_control_latency l;
1643
+ l.enable = 1;
1644
+
1645
+ int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l));
1646
+ if (err != 0) {
1647
+ GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err);
1648
+ }
1649
+ }
1650
+
1651
+ // Now let's setup the DSP queue
1652
+ err = dspqueue_create(this->domain_id,
1653
+ 0, // Flags
1654
+ 128 * 1024, // Request queue size (in bytes)
1655
+ 64 * 1024, // Response queue size (in bytes)
1656
+ nullptr, // Read packet callback (we handle reads explicitly)
1657
+ nullptr, // Error callback (we handle errors during reads)
1658
+ (void *) this, // Callback context
1659
+ &queue);
1660
+ if (err != 0) {
1661
+ GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err);
1662
+ throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)");
1663
+ }
1664
+
1665
+ this->valid_queue = true;
1666
+
1667
+ // Export queue for use on the DSP
1668
+ err = dspqueue_export(queue, &this->queue_id);
1669
+ if (err != 0) {
1670
+ GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err);
1671
+ throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)");
1672
+ }
1673
+
1674
+ if (opt_etm) {
1675
+ err = htp_iface_enable_etm(this->handle);
1676
+ if (err != 0) {
1677
+ GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err);
1678
+ }
1679
+ }
1680
+
1681
+ // Start the DSP-side service. We need to pass the queue ID to the
1682
+ // DSP in a FastRPC call; the DSP side will import the queue and start
1683
+ // listening for packets in a callback.
1684
+ err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx);
1685
+ if (err != 0) {
1686
+ GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err);
1687
+ throw std::runtime_error("ggml-hex: iface start failed (see log for details)");
1688
+ }
1689
+ this->valid_iface = true;
1690
+ }
1691
+
1692
+ void ggml_hexagon_session::release() noexcept(true) {
1693
+ GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str());
1694
+
1695
+ int err;
1696
+
1697
+ // Stop the DSP-side service and close the queue
1698
+ if (this->valid_iface) {
1699
+ err = htp_iface_stop(this->handle);
1700
+ if (err != 0) {
1701
+ GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err);
1702
+ }
1703
+ }
1704
+
1705
+ if (opt_etm) {
1706
+ err = htp_iface_disable_etm(this->handle);
1707
+ if (err != 0) {
1708
+ GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err);
1709
+ }
1710
+ }
1711
+
1712
+ if (this->valid_queue) {
1713
+ err = dspqueue_close(queue);
1714
+ if (err != 0) {
1715
+ GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err);
1716
+ }
1717
+ }
1718
+
1719
+ if (this->valid_handle) {
1720
+ htp_iface_close(this->handle);
1721
+ }
1722
+ }
1723
+
1724
// Construct a session for device dev_id: allocate the DSP session, then wire
// up the two buffer types (plain and repack) that share it.
// On failure the partially built session is released and the exception is
// rethrown, so the caller never sees a half-initialized object.
ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) {
    buffer_type.device        = dev;
    repack_buffer_type.device = dev;

    try {
        allocate(dev_id);

        buffer_type.iface   = ggml_backend_hexagon_buffer_type_interface;
        buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this);

        repack_buffer_type.iface   = ggml_backend_hexagon_repack_buffer_type_interface;
        repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this);
    } catch (const std::exception & exc) {
        // NOTE(review): if the second `new` above throws, buffer_type.context
        // leaks (release() only tears down the DSP session) — confirm whether
        // this path matters in practice.
        release();
        throw;
    }
}
1741
+
1742
// Release the DSP session and free the buffer-type contexts created in the
// constructor (the buffer_type structs themselves are members, not owned heap
// objects).
ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
    release();

    delete static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type.context);
    delete static_cast<ggml_backend_hexagon_buffer_type_context *>(repack_buffer_type.context);
}
1748
+
1749
+ // ** backend interface
1750
+
1751
// Identify hexagon buffers by function-pointer identity: all hexagon buffer
// types (plain and repack) share get_alignment, so comparing that member is a
// cheap type test. Do not restructure the buffer-type vtables without keeping
// this invariant.
static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) {
    return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment;
}

// Repack buffers are distinguished by their unique alloc_buffer callback;
// this is why the plain/repack alloc functions must stay distinct symbols.
static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {
    return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
}
1758
+
1759
+ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
1760
+ if (x->ne[0] != y->ne[0]) {
1761
+ return false;
1762
+ }
1763
+ if (x->ne[1] != y->ne[1]) {
1764
+ return false;
1765
+ }
1766
+ if (x->ne[2] != y->ne[2]) {
1767
+ return false;
1768
+ }
1769
+ if (x->ne[3] != y->ne[3]) {
1770
+ return false;
1771
+ }
1772
+
1773
+ return true;
1774
+ }
1775
+
1776
// Check whether a FLASH_ATTN_EXT op can be offloaded to the hexagon backend.
// Gated behind opt_experimental. Operand layout follows ggml's flash-attn
// convention: src0/src1/src2 are presumably Q/K/V, src3 the mask, src4 the
// sinks — TODO confirm against the ggml op definition.
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0];
    const struct ggml_tensor * src1 = op->src[1];
    const struct ggml_tensor * src2 = op->src[2];
    const struct ggml_tensor * src3 = op->src[3];
    const struct ggml_tensor * src4 = op->src[4];
    const struct ggml_tensor * dst  = op;

    // Check for F16 support only as requested
    // (src0 may additionally be F32; src1/src2 must be F16)
    if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
        return false;
    }

    if (src3 && src3->type != GGML_TYPE_F16) { // mask
        return false;
    }

    if (src4 && src4->type != GGML_TYPE_F32) { // sinks
        return false;
    }

    // For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
    // but the op implementation writes to F16 or F32.
    // Let's assume dst can be F32 or F16.
    if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
        return false;
    }

    // Only enabled when the experimental-ops option is set
    return opt_experimental;
}
1806
+
1807
// Small type predicates used by the per-op support checks below.

// Default first-operand type: F32 only.
static bool hex_supported_src0_type(ggml_type t) {
    return t == GGML_TYPE_F32;
}

// Default second-operand type: F32 only.
static bool hex_supported_src1_type(ggml_type t) {
    return t == GGML_TYPE_F32;
}

// Default third-operand type (e.g. rope frequency factors): F32 only.
static bool hex_supported_src2_type(ggml_type t) {
    return t == GGML_TYPE_F32;
}

// Alternate second-operand type (e.g. F16 softmax mask).
static bool hex_supported_src1_type2(ggml_type t) {
    return t == GGML_TYPE_F16;
}

// Integer second-operand type (e.g. rope positions).
static bool hex_supported_src1_type3(ggml_type t) {
    return t == GGML_TYPE_I32;
}

// Default output type: F32 only.
static bool hex_supported_dst_type(ggml_type t) {
    return t == GGML_TYPE_F32;
}
1830
+
1831
+ static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
1832
+ // TODO: support broadcast for ne[2 and 3]
1833
+ if (x->ne[0] != y->ne[0]) {
1834
+ return false;
1835
+ }
1836
+ if (x->ne[2] != y->ne[2]) {
1837
+ return false;
1838
+ }
1839
+ if (x->ne[3] != y->ne[3]) {
1840
+ return false;
1841
+ }
1842
+ return true;
1843
+ }
1844
+
1845
// Check whether a MUL_MAT op can be offloaded to this session.
// Supported weight (src0) types: Q4_0 / Q8_0 / MXFP4 (repacked) and F16.
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0]; // weights
    const struct ggml_tensor * src1 = dst->src[1]; // activations

    if (dst->type != GGML_TYPE_F32) {
        return false;
    }

    if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
        return false;
    }

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_MXFP4:
            // row width must be a whole number of 32-element quant blocks
            if (src0->ne[0] % 32) {
                return false;
            }

            if (src0->ne[1] > 16 * 1024) {
                return false; // typically the lm-head which would be too large for VTCM
            }

            // no batched/3d activations for quantized mulmat
            if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
                return false;
            }

            // src0 (weights) must be repacked
            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
                return false;
            }
            break;

        case GGML_TYPE_F16:
            // nb[1] < nb[0] indicates a permuted/transposed layout
            if (src0->nb[1] < src0->nb[0]) {
                GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n");
                return false;
            }
            break;

        default:
            return false;
    }

    return true;
}
1892
+
1893
// Check whether a MUL_MAT_ID (expert-routed matmul) op can be offloaded.
// Weights must be quantized (Q4_0/Q8_0/MXFP4) and live in a repack buffer;
// src2 carries the expert indices.
static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
    const struct ggml_tensor * src0 = op->src[0]; // expert weights
    const struct ggml_tensor * src1 = op->src[1]; // activations
    const struct ggml_tensor * src2 = op->src[2]; // expert ids (I32)
    const struct ggml_tensor * dst  = op;

    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_I32) {
        return false;
    }

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_MXFP4:
            // row width must be a whole number of 32-element quant blocks
            if ((src0->ne[0] % 32)) {
                return false;
            }

            // src0 (weights) must be repacked
            if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) {
                return false;
            }
            break;

        default:
            return false;
    }

    return true;
}
1923
+
1924
+ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1925
+ const struct ggml_tensor * src0 = op->src[0];
1926
+ const struct ggml_tensor * src1 = op->src[1];
1927
+ const struct ggml_tensor * dst = op;
1928
+
1929
+ if (!hex_supported_src0_type(src0->type)) {
1930
+ return false;
1931
+ }
1932
+ if (!hex_supported_src1_type(src1->type)) {
1933
+ return false;
1934
+ }
1935
+ if (!hex_supported_dst_type(dst->type)) {
1936
+ return false;
1937
+ }
1938
+ if (!hex_supported_dims2(src0, dst)) {
1939
+ return false;
1940
+ }
1941
+ if (!ggml_can_repeat(src1, src0)) {
1942
+ return false;
1943
+ }
1944
+
1945
+ // TODO: add support for non-contigiuos tensors
1946
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
1947
+ return false;
1948
+ }
1949
+
1950
+ return true;
1951
+ }
1952
+
1953
+ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1954
+ const struct ggml_tensor * src0 = op->src[0];
1955
+ const struct ggml_tensor * src1 = op->src[1];
1956
+ const struct ggml_tensor * dst = op;
1957
+
1958
+ if (!hex_supported_src0_type(src0->type)) {
1959
+ return false;
1960
+ }
1961
+ if (!hex_supported_src1_type(src1->type)) {
1962
+ return false;
1963
+ }
1964
+ if (!hex_supported_dst_type(dst->type)) {
1965
+ return false;
1966
+ }
1967
+ if (!hex_supported_dims2(src0, dst)) {
1968
+ return false;
1969
+ }
1970
+
1971
+ // REVISIT: add support for non-contigiuos tensors
1972
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
1973
+ return false;
1974
+ }
1975
+
1976
+ return true;
1977
+ }
1978
+
1979
+ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
1980
+ const struct ggml_tensor * src0 = op->src[0];
1981
+ const struct ggml_tensor * dst = op;
1982
+
1983
+ if (!hex_supported_src0_type(src0->type)) {
1984
+ return false;
1985
+ }
1986
+ if (!hex_supported_dst_type(dst->type)) {
1987
+ return false;
1988
+ }
1989
+ if (!hex_supported_dims2(src0, dst)) {
1990
+ return false;
1991
+ }
1992
+
1993
+ // TODO: add support for non-contigiuos tensors
1994
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
1995
+ return false;
1996
+ }
1997
+
1998
+ return true;
1999
+ }
2000
+
2001
+ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
2002
+ const struct ggml_tensor * op) {
2003
+ const struct ggml_tensor * src0 = op->src[0];
2004
+ const struct ggml_tensor * src1 = op->src[1];
2005
+ const struct ggml_tensor * dst = op;
2006
+
2007
+ if (!hex_supported_src0_type(src0->type)) {
2008
+ return false;
2009
+ }
2010
+ if (!hex_supported_dst_type(dst->type)) {
2011
+ return false;
2012
+ }
2013
+
2014
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2015
+ return false;
2016
+ }
2017
+
2018
+ if (src1) {
2019
+ if (!hex_supported_src1_type(src1->type)) {
2020
+ return false;
2021
+ }
2022
+ if (!hex_supported_dims2(src0, src1)) {
2023
+ return false;
2024
+ }
2025
+ if (!ggml_is_contiguous(src1)) {
2026
+ return false;
2027
+ }
2028
+ }
2029
+
2030
+ return true;
2031
+ }
2032
+
2033
+ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2034
+ const struct ggml_tensor * src0 = op->src[0];
2035
+ const struct ggml_tensor * src1 = op->src[1];
2036
+ const struct ggml_tensor * src2 = op->src[2];
2037
+ const struct ggml_tensor * dst = op;
2038
+
2039
+ if (src2) {
2040
+ return false; // FIXME: add support for sinks
2041
+ }
2042
+
2043
+ if (!hex_supported_src0_type(src0->type)) {
2044
+ return false;
2045
+ }
2046
+ if (!hex_supported_dst_type(dst->type)) {
2047
+ return false;
2048
+ }
2049
+
2050
+ if (src1) {
2051
+ if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
2052
+ return false;
2053
+ }
2054
+ if (src0->ne[0] != src1->ne[0]) {
2055
+ return false;
2056
+ }
2057
+ if (src1->ne[1] < src0->ne[1]) {
2058
+ return false;
2059
+ }
2060
+ if (src0->ne[2] % src1->ne[2] != 0) {
2061
+ return false;
2062
+ }
2063
+ if (src0->ne[3] % src1->ne[3] != 0) {
2064
+ return false;
2065
+ }
2066
+ }
2067
+
2068
+ if (src1) {
2069
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
2070
+ return false;
2071
+ }
2072
+ } else {
2073
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
2074
+ return false;
2075
+ }
2076
+ }
2077
+
2078
+ return true;
2079
+ }
2080
+
2081
+ static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2082
+ const struct ggml_tensor * src0 = op->src[0]; // values
2083
+ const struct ggml_tensor * src1 = op->src[1]; // indices
2084
+ const struct ggml_tensor * dst = op;
2085
+
2086
+ if (src0->type != GGML_TYPE_F32) {
2087
+ return false;
2088
+ }
2089
+
2090
+ if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
2091
+ return false;
2092
+ }
2093
+
2094
+ if (dst->type != GGML_TYPE_F16) {
2095
+ return false;
2096
+ }
2097
+
2098
+ return true;
2099
+ }
2100
+
2101
+ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2102
+ const struct ggml_tensor * src0 = op->src[0]; // values
2103
+ const struct ggml_tensor * src1 = op->src[1]; // indices
2104
+ const struct ggml_tensor * dst = op;
2105
+
2106
+ if (src0->type != GGML_TYPE_F32) {
2107
+ return false;
2108
+ }
2109
+
2110
+ if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
2111
+ return false;
2112
+ }
2113
+
2114
+ if (dst->type != GGML_TYPE_F32) {
2115
+ return false;
2116
+ }
2117
+
2118
+ return true;
2119
+ }
2120
+
2121
+ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
2122
+ const int32_t * op_params = &op->op_params[0];
2123
+
2124
+ int mode = op_params[2];
2125
+
2126
+ if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) {
2127
+ return false;
2128
+ }
2129
+ if (mode & 1) {
2130
+ return false;
2131
+ }
2132
+
2133
+ const struct ggml_tensor * src0 = op->src[0];
2134
+ const struct ggml_tensor * src1 = op->src[1];
2135
+ const struct ggml_tensor * src2 = op->src[2];
2136
+ const struct ggml_tensor * dst = op;
2137
+
2138
+ if (!hex_supported_src0_type(src0->type)) {
2139
+ return false; // FIXME: add support for GGML_TYPE_F16 for src0
2140
+ }
2141
+ if (!hex_supported_dst_type(dst->type)) {
2142
+ return false;
2143
+ }
2144
+ if (!hex_supported_src1_type3(src1->type)) {
2145
+ return false;
2146
+ }
2147
+ if (src2) {
2148
+ if (!hex_supported_src2_type(src2->type)) {
2149
+ return false;
2150
+ }
2151
+ int n_dims = op_params[1];
2152
+ if (src2->ne[0] < (n_dims / 2)) {
2153
+ return false;
2154
+ }
2155
+ }
2156
+
2157
+ if (src2) {
2158
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(src2) ||
2159
+ !ggml_is_contiguous(dst)) {
2160
+ return false;
2161
+ }
2162
+ } else {
2163
+ if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
2164
+ return false;
2165
+ }
2166
+ }
2167
+
2168
+ return true;
2169
+ }
2170
+
2171
+ enum dspqbuf_type {
2172
+ DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
2173
+ DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
2174
+ DSPQBUF_TYPE_CONSTANT,
2175
+ };
2176
+
2177
+ static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) {
2178
+ if (opt_verbose < 2) return;
2179
+
2180
+ auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);
2181
+ auto sess = buf->sess;
2182
+
2183
+ GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(),
2184
+ t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset,
2185
+ (unsigned int) d->size);
2186
+ }
2187
+
2188
+ // Init hexagon tensor from GGML tensor and Hexagon buffer
2189
+ static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) {
2190
+ h->data = 0; // updated by the receiver
2191
+ h->type = t->type;
2192
+ h->ne[0] = t->ne[0];
2193
+ h->ne[1] = t->ne[1];
2194
+ h->ne[2] = t->ne[2];
2195
+ h->ne[3] = t->ne[3];
2196
+ h->nb[0] = t->nb[0];
2197
+ h->nb[1] = t->nb[1];
2198
+ h->nb[2] = t->nb[2];
2199
+ h->nb[3] = t->nb[3];
2200
+ }
2201
+
2202
// Fill one dspqueue buffer descriptor (and the matching htp_tensor header)
// for tensor t. Returns the number of descriptors written: 1, or 0 when t is
// NULL (so callers can accumulate a packed descriptor array).
static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) {
    if (!t) {
        return 0;
    }

    auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context);

    memset(d, 0, sizeof(*d));
    d->fd     = buf->fd;                           // ION/dmabuf fd of the owning buffer
    d->ptr    = t->data;
    d->offset = (uint8_t *) t->data - buf->base;   // offset of the tensor within the buffer
    d->size   = ggml_nbytes(t);

    if (!d->size) {
        // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
        d->size = 64;
    }

    // Cache maintenance depends on which side produces/consumes the data
    switch (type) {
        case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
            // Flush CPU
            d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER;
            break;
        case DSPQBUF_TYPE_CPU_WRITE_DSP_READ:
            // Flush CPU, Invalidate DSP
            d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
            break;
        default:
            // Constant buffer, no cache maintenance
            d->flags = 0;
            break;
    }

    htp_req_tensor_init(h, t);

    dspqbuf_dump(d, t, type);

    return 1;
}
2241
+
2242
// Per-op request builder: fills the request struct and descriptor array,
// returning the number of buffers used.
typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op);

// Build and enqueue an HTP request for op using the supplied request-builder
// template parameter, honoring the opt_opmask debug mask (quantize / compute /
// queue stages can each be skipped) and recording wall-clock + DSP profiling.
template <htp_req_init_func_t _init_req_func>
static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) {
    uint64_t t = ggml_time_us();

    // Construct HTP request
    htp_general_req req;
    memset(&req, 0, sizeof(req));

    req.flags = flags;
    // Debug knobs: tell the DSP to skip the quantize and/or compute stages
    if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) {
        req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
    }
    if (!(opt_opmask & HTP_OPMASK_COMPUTE)) {
        req.flags |= HTP_OPFLAGS_SKIP_COMPUTE;
    }

    ggml_hexagon_dump_op_exec(sess->name, op, req.flags);

    // Optionally skip the queue submission entirely (debug)
    if ((opt_opmask & HTP_OPMASK_QUEUE)) {
        dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
        size_t n_bufs = _init_req_func(&req, bufs, op);
        sess->enqueue(req, bufs, n_bufs, opt_opsync);
    }

    t = ggml_time_us() - t;

    ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t);
}
2272
+
2273
// Build a request for a binary op (MUL_MAT / MUL / ADD / SUB).
// The descriptor order (src0, src1, dst) must stay in sync with the DSP-side
// unpacking. When _is_src0_constant (e.g. weights), src0 gets no cache
// maintenance.
template <bool _is_src0_constant>
static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
    switch (t->op) {
        case GGML_OP_MUL_MAT:
            req->op = HTP_OP_MUL_MAT;
            break;
        case GGML_OP_MUL:
            req->op = HTP_OP_MUL;
            break;
        case GGML_OP_ADD:
            req->op = HTP_OP_ADD;
            break;
        case GGML_OP_SUB:
            req->op = HTP_OP_SUB;
            break;
        default:
            GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
            break;
    }

    // src0: Weights (mulmat) or First Operand (binary op).
    //       If constant (e.g. weights), no cache management is needed.
    // src1: Input Activations (mulmat) or Second Operand (binary op).

    size_t n_bufs = 0;
    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);

    return n_bufs;
}
2304
+
2305
+ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2306
+ req->op = HTP_OP_GET_ROWS;
2307
+
2308
+ size_t n_bufs = 0;
2309
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2310
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2311
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2312
+
2313
+ return n_bufs;
2314
+ }
2315
+
2316
+ template <bool _is_src0_constant>
2317
+ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2318
+ switch (t->op) {
2319
+ case GGML_OP_MUL_MAT_ID:
2320
+ req->op = HTP_OP_MUL_MAT_ID;
2321
+ break;
2322
+ case GGML_OP_ADD_ID:
2323
+ req->op = HTP_OP_ADD_ID;
2324
+ break;
2325
+ default:
2326
+ GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op);
2327
+ }
2328
+
2329
+ // src0: Weights (mulmat) or Input Activations (other op).
2330
+ // If constant, no cache management is needed.
2331
+ // src1: Input Activations (mulmat) or Second Operand (binary op).
2332
+ // src2: Expert IDs (mulmat) or Activated Experts (other op).
2333
+
2334
+ size_t n_bufs = 0;
2335
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2336
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2337
+ n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2338
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2339
+
2340
+ return n_bufs;
2341
+ }
2342
+
2343
+ static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2344
+ req->op = HTP_OP_SET_ROWS;
2345
+
2346
+ size_t n_bufs = 0;
2347
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2348
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2349
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2350
+
2351
+ return n_bufs;
2352
+ }
2353
+
2354
+ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2355
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2356
+
2357
+ bool supported = false;
2358
+
2359
+ switch (t->op) {
2360
+ case GGML_OP_RMS_NORM:
2361
+ req->op = HTP_OP_RMS_NORM;
2362
+ supported = true;
2363
+ break;
2364
+
2365
+ case GGML_OP_SCALE:
2366
+ req->op = HTP_OP_SCALE;
2367
+ supported = true;
2368
+ break;
2369
+
2370
+ case GGML_OP_UNARY:
2371
+ if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
2372
+ req->op = HTP_OP_UNARY_SILU;
2373
+ supported = true;
2374
+ } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
2375
+ req->op = HTP_OP_UNARY_GELU;
2376
+ supported = true;
2377
+ }
2378
+ break;
2379
+
2380
+ case GGML_OP_GLU:
2381
+ if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) {
2382
+ req->op = HTP_OP_GLU_SWIGLU;
2383
+ supported = true;
2384
+ } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
2385
+ req->op = HTP_OP_GLU_SWIGLU_OAI;
2386
+ supported = true;
2387
+ }
2388
+ break;
2389
+
2390
+ case GGML_OP_SOFT_MAX:
2391
+ req->op = HTP_OP_SOFTMAX;
2392
+ supported = true;
2393
+ break;
2394
+
2395
+ default:
2396
+ break;
2397
+ }
2398
+
2399
+ if (!supported) {
2400
+ GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op);
2401
+ }
2402
+
2403
+ size_t n_bufs = 0;
2404
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2405
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2406
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2407
+
2408
+ return n_bufs;
2409
+ }
2410
+
2411
+ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2412
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2413
+ req->op = HTP_OP_ROPE;
2414
+
2415
+ size_t n_bufs = 0;
2416
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2417
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2418
+ n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2419
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2420
+
2421
+ return n_bufs;
2422
+ }
2423
+
2424
+ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
2425
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
2426
+ req->op = HTP_OP_FLASH_ATTN_EXT;
2427
+
2428
+ size_t n_bufs = 0;
2429
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2430
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2431
+ n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2432
+ n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2433
+ n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
2434
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
2435
+
2436
+ return n_bufs;
2437
+ }
2438
+
2439
+ static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
2440
+ auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2441
+ return sess->name.c_str();
2442
+ }
2443
+
2444
+ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
2445
+ // we just need to delete the backend here
2446
+ // the sessions are allocated & freed as part of the registry
2447
+ delete backend;
2448
+ }
2449
+
2450
+ static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
2451
+ return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type));
2452
+ }
2453
+
2454
+ static inline bool is_compute_op(ggml_tensor *node)
2455
+ {
2456
+ return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
2457
+ }
2458
+
2459
+ // scan the graph and figure out last compute op index
2460
+ static inline int last_compute_op(ggml_cgraph * graph) {
2461
+ int last = 0;
2462
+ for (int i = 0; i < graph->n_nodes; ++i) {
2463
+ if (is_compute_op(graph->nodes[i])) {
2464
+ last = i;
2465
+ }
2466
+ }
2467
+
2468
+ return last;
2469
+ }
2470
+
2471
// Execute the graph on the Hexagon NPU: dispatch every compute node to the
// session's DSP queue, then block until all enqueued ops have completed.
// Nodes whose op/type is not handled here abort (supports_op filters earlier).
static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
    auto sess = static_cast<ggml_hexagon_session *>(backend->context);

    HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes);

    // index of the last real compute node -- that op requests early wakeup
    const int last = last_compute_op(graph);

    const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer

    for (int i = 0; i < graph->n_nodes; ++i) {
        ggml_tensor * node = graph->nodes[i];

        // no-op / empty tensors need no dispatch
        if (!is_compute_op(node)) {
            continue;
        }

        uint32_t flags = 0;

        // skip quantizer if src1 is reused
        if (op_reuse_src1(node, prev_quant_op)) {
            flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
        }

        // ask for early notification for the last Op
        if (i == last) {
            flags |= HTP_OPFLAGS_EARLY_WAKEUP;
        }

        // The template argument picks constant (no cache maintenance) vs
        // writable treatment for src0: quantized src0 tensors are weights.
        switch (node->op) {
            case GGML_OP_MUL_MAT:
                if (ggml_is_quantized(node->src[0]->type)) {
                    ggml_hexagon_dispatch_op<init_binary_req<true>>(sess, node, flags);
                } else {
                    ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
                }
                prev_quant_op = node;
                break;
            case GGML_OP_MUL_MAT_ID:
                if (ggml_is_quantized(node->src[0]->type)) {
                    ggml_hexagon_dispatch_op<init_binary_id_req<true>>(sess, node, flags);
                } else {
                    ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
                }
                prev_quant_op = node;
                break;
            case GGML_OP_MUL:
            case GGML_OP_ADD:
            case GGML_OP_SUB:
                ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
                break;
            case GGML_OP_ADD_ID:
                ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
                break;
            case GGML_OP_RMS_NORM:
            case GGML_OP_SCALE:
                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
                break;
            case GGML_OP_UNARY:
                // only SILU / GELU are dispatched; other unary ops fall through silently
                if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
                    (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
                    ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
                }
                break;
            case GGML_OP_GLU:
                // only the SWIGLU variants are dispatched
                if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
                    (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
                    ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
                }
                break;
            case GGML_OP_SOFT_MAX:
                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
                break;

            case GGML_OP_ROPE:
                ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
                break;

            case GGML_OP_FLASH_ATTN_EXT:
                ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
                break;

            case GGML_OP_SET_ROWS:
                ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
                break;

            case GGML_OP_GET_ROWS:
                ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
                break;

            default:
                GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
        }
    }

    // Wait until all pending ops complete
    sess->flush();

    return GGML_STATUS_SUCCESS;
}
2570
+
2571
+ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) {
2572
+ auto sess = static_cast<ggml_hexagon_session *>(backend->context);
2573
+
2574
+ HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str());
2575
+
2576
+ // Wait until all pending ops complete
2577
+ sess->flush();
2578
+ }
2579
+
2580
+ struct node_info {
2581
+ ggml_tensor * node;
2582
+
2583
+ std::vector<ggml_tensor *> fused;
2584
+
2585
+ ggml_op op() const {
2586
+ return node->op;
2587
+ }
2588
+
2589
+ const ggml_tensor * dst() const {
2590
+ return fused.empty() ? node : fused.back();
2591
+ }
2592
+
2593
+ const ggml_tensor * src0() const {
2594
+ return node->src[0];
2595
+ }
2596
+
2597
+ const ggml_tensor * src1() const {
2598
+ return node->src[1];
2599
+ }
2600
+
2601
+ bool is_empty() const {
2602
+ return ggml_op_is_empty(node->op);
2603
+ }
2604
+
2605
+ void add_fused(ggml_tensor * t) {
2606
+ fused.push_back(t);
2607
+ }
2608
+
2609
+ bool stackable() const {
2610
+ switch (this->op()) {
2611
+ case GGML_OP_MUL_MAT:
2612
+ case GGML_OP_MUL_MAT_ID:
2613
+ return ggml_is_quantized(this->src0()->type);
2614
+ default:
2615
+ return false;
2616
+ }
2617
+ }
2618
+
2619
+ bool same_input(const node_info& n) const {
2620
+ return n.src1() == this->src1();
2621
+ }
2622
+ };
2623
+
2624
// Compute a new execution order (as indices into `nodes`) that places
// stackable MUL_MAT ops sharing the same src1 next to each other, so the
// dynamically quantized src1 in VTCM can be reused across them.
static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) {
    const int n = nodes.size();

    std::vector<int> res;
    res.reserve(n);

    // marks nodes already pulled forward next to an earlier anchor
    std::vector<bool> used(n, false);

    // The main goal here is to stack the MUL_MAT ops with the same src1 input.
    // This allows use to reuse dynamically quantized src1 in VTCM.

    // TODO: the current version might do incorrect reodering in cases where quantized src0
    // input is an output of another Op.

    for (int i0 = 0; i0 < n; i0++) {
        if (used[i0]) {
            continue;
        }

        res.push_back(i0);

        const auto & node0 = nodes[i0];

        if (!node0.stackable()) {
            continue;
        }

        // that many nodes forward to search for stackable nodes that can reuse VTCM
        constexpr int N_FORWARD = 8;

        // pull matching stackable nodes up, right behind the anchor node0
        for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
            if (used[i1]) {
                continue;
            }

            const auto & node1 = nodes[i1];

            if (node1.stackable() && node1.same_input(node0)) {
                res.push_back(i1);
                used[i1] = true;
            }
        }
    }

    return res;
}
2670
+
2671
// Graph-level optimization pass: pack fusable tensor runs into single
// node_info entries, reorder the packed nodes (to stack matmuls sharing
// src1), then write the unpacked order back into the graph in place.
static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgraph * gf) {
    const int n = gf->n_nodes;

    // upper bound on the length of a fused run
    constexpr int MAX_FUSE = 16;

    enum ggml_op ops[MAX_FUSE];

    std::vector<node_info> nodes;
    nodes.reserve(gf->n_nodes);

    // fuse nodes:
    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
    // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
    for (int i = 0; i < n; i++) {
        node_info node = {
            /*.node =*/gf->nodes[i],
            /*.fused =*/{},
        };

        // fuse only ops that start with these operations
        // can be expanded when needed
        if (node.op() == GGML_OP_ADD ||
            node.op() == GGML_OP_NORM ||
            node.op() == GGML_OP_RMS_NORM) {
            ops[0] = node.op();

            // collect the candidate op run following node i
            int f = i + 1;
            while (f < n && f < i + MAX_FUSE) {
                // conservatively allow fusing only these ops
                // can be expanded when needed
                if (gf->nodes[f]->op != GGML_OP_ADD &&
                    gf->nodes[f]->op != GGML_OP_MUL &&
                    gf->nodes[f]->op != GGML_OP_NORM &&
                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
                    break;
                }
                ops[f - i] = gf->nodes[f]->op;
                f++;
            }

            // shrink the run until ggml accepts it as a fusable sequence
            f -= i;
            for (; f > 1; f--) {
                if (ggml_can_fuse(gf, i, ops, f)) {
                    break;
                }
            }

            // add the fused tensors into the node info so we can unfuse them later
            for (int k = 1; k < f; k++) {
                ++i;

                // the .dst() becomes the last fused tensor
                node.add_fused(gf->nodes[i]);
            }
        }

        nodes.push_back(std::move(node));
    }

    const auto order = ggml_hexagon_graph_optimize_reorder(nodes);

    // unfuse: rewrite gf->nodes in the new order, expanding fused runs
    {
        int j = 0;
        for (const auto i : order) {
            const auto & node = nodes[i];

            gf->nodes[j++] = node.node;

            for (auto * fused : node.fused) {
                gf->nodes[j++] = fused;
            }
        }
    }
}
2746
+
2747
// Backend vtable. Async tensor transfers, graph plans and events are not
// implemented; compute, synchronize and graph-optimize are.
static struct ggml_backend_i hexagon_backend_i = {
    /* .get_name           = */ ggml_backend_hexagon_name,
    /* .free               = */ ggml_backend_hexagon_free,
    /* .set_tensor_async   = */ NULL,
    /* .get_tensor_async   = */ NULL,
    /* .cpy_tensor_async   = */ NULL,
    /* .synchronize        = */ ggml_backend_hexagon_synchronize,
    /* .graph_plan_create  = */ NULL,
    /* .graph_plan_free    = */ NULL,
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_hexagon_graph_compute,
    /* .event_record       = */ NULL,
    /* .event_wait         = */ NULL,
    /* .graph_optimize     = */ ggml_backend_hexagon_graph_optimize,
};
2763
+
2764
+ static ggml_guid_t ggml_backend_hexagon_guid() {
2765
+ static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49,
2766
+ 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11 };
2767
+ return &guid;
2768
+ }
2769
+
2770
+ bool ggml_backend_is_hexagon(ggml_backend_t backend) {
2771
+ return backend && backend->iface.get_name == ggml_backend_hexagon_name;
2772
+ }
2773
+
2774
+ // device interface
2775
+
2776
// Create a backend instance bound to this device's session. The session is
// owned by the registry, so ggml_backend_hexagon_free only deletes the
// backend object itself.
static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, const char * params) {
    auto sess = static_cast<ggml_hexagon_session *>(dev->context);

    return new ggml_backend{
        /* .guid      = */ ggml_backend_hexagon_guid(),
        /* .interface = */ hexagon_backend_i,
        /* .device    = */ dev,
        /* .context   = */ sess,
    };

    // unreachable; silences unused-parameter warnings (ggml idiom)
    GGML_UNUSED(params);
}
2788
+
2789
+ static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) {
2790
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2791
+ return sess->name.c_str();
2792
+
2793
+ GGML_UNUSED(dev);
2794
+ }
2795
+
2796
+ static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev_t dev) {
2797
+ return "Hexagon";
2798
+ GGML_UNUSED(dev);
2799
+ }
2800
+
2801
+ static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2802
+ // ~2GB per session for now
2803
+ *free = 2ULL * 1024 * 1024 * 1024;
2804
+ *total = *free;
2805
+
2806
+ GGML_UNUSED(dev);
2807
+ }
2808
+
2809
+ static enum ggml_backend_dev_type ggml_backend_hexagon_device_get_type(ggml_backend_dev_t dev) {
2810
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
2811
+
2812
+ GGML_UNUSED(dev);
2813
+ }
2814
+
2815
// Fill in the device properties struct consumed by the ggml backend registry.
static void ggml_backend_hexagon_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name = ggml_backend_hexagon_device_get_name(dev);
    props->description = ggml_backend_hexagon_device_get_description(dev);
    props->type = ggml_backend_hexagon_device_get_type(dev);
    ggml_backend_hexagon_device_get_memory(dev, &props->memory_free, &props->memory_total);
    props->caps = {
        /* .async                 = */ true,
        /* .host_buffer           = */ (bool) opt_hostbuf,  // toggled via GGML_HEXAGON_HOSTBUF
        /* .buffer_from_host_ptr  = */ false,
        /* .events                = */ false,
    };
}
2827
+
2828
+ static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_buffer_type(ggml_backend_dev_t dev) {
2829
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2830
+ return &sess->buffer_type;
2831
+ }
2832
+
2833
+ static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_repack_buffer_type(ggml_backend_dev_t dev) {
2834
+ auto sess = static_cast<ggml_hexagon_session *>(dev->context);
2835
+ return &sess->repack_buffer_type;
2836
+ }
2837
+
2838
+ static bool ggml_hexagon_supported_buffer(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
2839
+ if (t && t->buffer) {
2840
+ if (ggml_backend_buffer_is_hexagon(t->buffer) == false) return false; // not our buffer
2841
+ if (ggml_backend_hexagon_buffer_get_sess(t->buffer) != sess) return false; // wrong session
2842
+ }
2843
+ return true;
2844
+ }
2845
+
2846
+ static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const struct ggml_tensor * t) {
2847
+ // all srcs & dsts must be mapped to the same session
2848
+ if (!ggml_hexagon_supported_buffer(sess, t)) {
2849
+ return false;
2850
+ }
2851
+
2852
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
2853
+ if (!ggml_hexagon_supported_buffer(sess, t->src[i])) {
2854
+ return false;
2855
+ }
2856
+ }
2857
+
2858
+ return true;
2859
+ }
2860
+
2861
// Device-level op support check: an op is accepted only when all of its
// tensors are mapped to this session AND the op/type combination passes the
// corresponding ggml_hexagon_supported_* predicate. Every decision is traced
// via ggml_hexagon_dump_op_supp.
static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    auto sess = static_cast<ggml_hexagon_session *>(dev->context);

    // all srcs & dsts must be mapped to the same session
    if (!ggml_hexagon_supported_buffers(sess, op)) {
        ggml_hexagon_dump_op_supp(sess->name, op, false);
        return false;
    }

    bool supp = false;
    switch (op->op) {
        // view-like / no-op tensors are always supported
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            supp = true;
            break;

        case GGML_OP_MUL_MAT:
            supp = ggml_hexagon_supported_mul_mat(sess, op);
            break;

        case GGML_OP_MUL_MAT_ID:
            supp = ggml_hexagon_supported_mul_mat_id(sess, op);
            break;

        case GGML_OP_MUL:
        case GGML_OP_ADD:
        case GGML_OP_SUB:
            supp = ggml_hexagon_supported_binary(sess, op);
            break;

        case GGML_OP_ADD_ID:
            supp = ggml_hexagon_supported_add_id(sess, op);
            break;

        case GGML_OP_RMS_NORM:
        case GGML_OP_SCALE:
            supp = ggml_hexagon_supported_unary(sess, op);
            break;

        case GGML_OP_SOFT_MAX:
            supp = ggml_hexagon_supported_softmax(sess, op);
            break;

        case GGML_OP_UNARY:
            {
                // only SILU / GELU activations are offloaded
                const auto unary_op = ggml_get_unary_op(op);
                if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
                    supp = ggml_hexagon_supported_activations(sess, op);
                }
                break;
            }
        case GGML_OP_GLU:
            {
                // only the SWIGLU variants are offloaded
                const auto glu_op = ggml_get_glu_op(op);
                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
                    supp = ggml_hexagon_supported_activations(sess, op);
                }
                break;
            }
        case GGML_OP_ROPE:
            supp = ggml_hexagon_supported_rope(sess, op);
            break;

        case GGML_OP_FLASH_ATTN_EXT:
            supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
            break;

        case GGML_OP_SET_ROWS:
            supp = ggml_hexagon_supported_set_rows(sess, op);
            break;

        case GGML_OP_GET_ROWS:
            supp = ggml_hexagon_supported_get_rows(sess, op);
            break;

        default:
            break;
    }

    ggml_hexagon_dump_op_supp(sess->name, op, supp);
    return supp;
}
2946
+
2947
// Check whether a buffer type is usable by this device: it must be one of our
// Hexagon buffer types (identified by the get_alignment entry point) and its
// session must share this device's session id.
static bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    // not a Hexagon buffer type at all
    if (buft->iface.get_alignment != ggml_backend_hexagon_buffer_type_get_alignment) {
        return false;
    }

    auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
    auto s1 = static_cast<ggml_backend_hexagon_buffer_type_context *>(buft->context)->sess;

    // Need session/domain-id for buffers to be compatible
    bool supp = (s0->session_id == s1->session_id);

    HEX_VERBOSE("ggml-hex: %s device-supports-buft %s (%d)\n", s0->name.c_str(), s1->name.c_str(), (int) supp);

    return supp;
}
2962
+
2963
// Expose the repack buffer type as an "extra" buffer type, reachable through
// the ggml_backend_dev_get_extra_bufts proc address. The returned array is
// NULL-terminated.
static ggml_backend_buffer_type_t * ggml_backend_hexagon_device_get_extra_buffers_type(ggml_backend_dev_t dev) {
    auto s0 = static_cast<ggml_hexagon_session *>(dev->context);
    HEX_VERBOSE("ggml-hex: device-get-extra-buft : %s \n", s0->name.c_str());

    // NOTE(review): static storage is shared across devices/threads -- the
    // slot is rewritten with the *current* device's repack type on each call.
    static ggml_backend_buffer_type_t bufts[2];
    bufts[0] = ggml_backend_hexagon_device_get_repack_buffer_type(dev);
    bufts[1] = NULL;
    return bufts;
}
2972
+
2973
// Device vtable. Host buffers, host-pointer buffers, offload-op and events
// are not implemented (the commented names mark planned hooks).
static const struct ggml_backend_device_i ggml_backend_hexagon_device_i = {
    /* .get_name             = */ ggml_backend_hexagon_device_get_name,
    /* .get_description      = */ ggml_backend_hexagon_device_get_description,
    /* .get_memory           = */ ggml_backend_hexagon_device_get_memory,
    /* .get_type             = */ ggml_backend_hexagon_device_get_type,
    /* .get_props            = */ ggml_backend_hexagon_device_get_props,
    /* .init_backend         = */ ggml_backend_hexagon_device_init,
    /* .get_buffer_type      = */ ggml_backend_hexagon_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL, // ggml_backend_hexagon_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ NULL, // ggml_backend_hexagon_device_buffer_from_ptr,
    /* .supports_op          = */ ggml_backend_hexagon_device_supports_op,
    /* .supports_buft        = */ ggml_backend_hexagon_device_supports_buft,
    /* .offload_op           = */ NULL, // ggml_backend_hexagon_device_offload_op,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
2990
+
2991
+ //** backend registry
2992
+
2993
+ #define GGML_HEXAGON_MAX_SESSIONS 16
2994
+
2995
// Backend registry state: one ggml device per Hexagon session. Slots whose
// session failed to initialize keep context == nullptr and are filtered out
// in ggml_backend_hexagon_reg_get_device().
struct ggml_hexagon_registry {
    ggml_hexagon_registry(ggml_backend_reg_t reg);
    ~ggml_hexagon_registry();

    // only the first opt_ndev entries are initialized
    ggml_backend_device devices[GGML_HEXAGON_MAX_SESSIONS];
};
3001
+
3002
+ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
3003
+ GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev);
3004
+
3005
+ if (!opt_arch) {
3006
+ int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch);
3007
+ if (err != 0) {
3008
+ GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err);
3009
+ opt_arch = 73;
3010
+ }
3011
+ }
3012
+
3013
+ if (opt_arch < 75) {
3014
+ opt_ndev = 1;
3015
+ GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
3016
+ }
3017
+
3018
+ GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
3019
+
3020
+ // Create devices / sessions
3021
+ for (size_t i = 0; i < opt_ndev; i++) {
3022
+ devices[i].iface = ggml_backend_hexagon_device_i;
3023
+ devices[i].reg = reg;
3024
+ try {
3025
+ devices[i].context = new ggml_hexagon_session(i, &devices[i]);
3026
+ } catch (const std::exception & exc) {
3027
+ GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i);
3028
+ devices[i].context = nullptr;
3029
+ }
3030
+ }
3031
+ }
3032
+
3033
+ ggml_hexagon_registry::~ggml_hexagon_registry() {
3034
+ GGML_LOG_INFO("ggml-hex: releasing registry\n");
3035
+
3036
+ // Release devices / sessions
3037
+ for (size_t i = 0; i < opt_ndev; i++) {
3038
+ auto sess = static_cast<ggml_hexagon_session *>(devices[i].context);
3039
+ delete sess;
3040
+ }
3041
+ }
3042
+
3043
+ static const char * ggml_backend_hexagon_reg_get_name(ggml_backend_reg_t reg) {
3044
+ return "HTP";
3045
+ GGML_UNUSED(reg);
3046
+ }
3047
+
3048
+ static size_t ggml_backend_hexagon_reg_get_device_count(ggml_backend_reg_t reg) {
3049
+ return opt_ndev;
3050
+ GGML_UNUSED(reg);
3051
+ }
3052
+
3053
+ static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t reg, size_t index) {
3054
+ auto hreg = static_cast<ggml_hexagon_registry *>(reg->context);
3055
+
3056
+ if (index >= opt_ndev || !hreg->devices[index].context) {
3057
+ return nullptr;
3058
+ }
3059
+
3060
+ return &hreg->devices[index];
3061
+ }
3062
+
3063
+ static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {
3064
+ if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
3065
+ ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type;
3066
+ return (void *) fct;
3067
+ }
3068
+
3069
+ return NULL;
3070
+ }
3071
+
3072
// One-time backend initialization: validates HTP/ggml type agreement, reads
// all GGML_HEXAGON_* environment variables into the opt_* globals, and
// allocates the registry (which creates the devices/sessions).
static void ggml_hexagon_init(ggml_backend_reg * reg) {
    // Basic sanity checks to make sure definitions match
    static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0,
                  "please update hexagon_type to match ggml_type");
    static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0,
                  "please update hexagon_type to match ggml_type");
    static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
                  "please update hexagon_type to match ggml_type");

    const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
    const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");

    // boolean toggles: presence of the variable enables the feature
    opt_verbose = str_verbose ? atoi(str_verbose) : 0;
    opt_profile = getenv("GGML_HEXAGON_PROFILE") != nullptr;
    opt_etm = getenv("GGML_HEXAGON_ETM") != nullptr;
    opt_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL") != nullptr;

    // op mask: strtoul with base 0 accepts decimal, hex (0x..) and octal
    const char * str_opmask = getenv("GGML_HEXAGON_OPMASK");
    if (str_opmask != nullptr) {
        opt_opmask = strtoul(str_opmask, NULL, 0);
    }
    opt_opsync = getenv("GGML_HEXAGON_OPSYNC") != nullptr;

    // number of devices/sessions, clamped to the compile-time maximum
    const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
    if (str_ndev) {
        opt_ndev = strtoul(str_ndev, NULL, 0);
        if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
            opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
        }
    }

    // number of HVX threads
    const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
    if (str_nhvx) {
        opt_nhvx = strtoul(str_nhvx, NULL, 0);
    }

    // architecture override: accepts "79" or "v79"
    const char * str_arch = getenv("GGML_HEXAGON_ARCH");
    if (str_arch) {
        if (str_arch[0] == 'v') {
            str_arch++;
        }
        opt_arch = strtoul(str_arch, NULL, 0);
    }

    // host buffers default to enabled
    opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;

    reg->context = new ggml_hexagon_registry(reg);

    HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
                sizeof(struct htp_general_rsp));
}
3123
+
3124
// Registry vtable wired into ggml's backend registration machinery.
static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = {
    /* .get_name         = */ ggml_backend_hexagon_reg_get_name,
    /* .get_device_count = */ ggml_backend_hexagon_reg_get_device_count,
    /* .get_device       = */ ggml_backend_hexagon_reg_get_device,
    /* .get_proc_address = */ ggml_backend_hexagon_get_proc_address,
};
3130
+
3131
// Public entry point: returns the process-wide Hexagon backend registry,
// performing one-time initialization under a mutex on first use.
ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
    static bool initialized = false;

    static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION,
                                    /* .iface       = */ ggml_backend_hexagon_reg_i,
                                    /* .context     = */ NULL };

    {
        // serialize concurrent first calls; init runs at most once
        static std::mutex mutex;
        std::lock_guard<std::mutex> lock(mutex);
        if (!initialized) {
            ggml_hexagon_init(&reg);
        }

        initialized = true;
    }

    return &reg;
}
3150
+
3151
+ GGML_BACKEND_DL_IMPL(ggml_backend_hexagon_reg)