whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -0,0 +1,1179 @@
1
+ #pragma once
2
+
3
+ #include <sycl/sycl.hpp>
4
+ #include "dpct/helper.hpp"
5
+ #include "common.hpp"
6
+ #include "convert.hpp"
7
+ #include "vecdotq.hpp"
8
+
9
+ #include "ggml.h"
10
+
11
+ #include <cstdint>
12
+ #include <cmath>
13
+ #include <float.h>
14
+
15
+
16
+ #define FATTN_KQ_STRIDE 256 // NOTE(review): presumably the K/V row granularity assumed by the flash-attention kernels - confirm at the use sites.
17
+ #define HALF_MAX_HALF sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
18
+ #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
19
+ #define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f) // 3 * 0.6931 (0.6931 ~= ln 2). NOTE(review): presumably an offset applied to the running KQ maximum - confirm at the use sites.
20
+
21
+ typedef void (*fattn_kernel_t)(
22
+ const char* Q,
23
+ const char* K,
24
+ const char* V,
25
+ const char* mask,
26
+ const char* sinks,
27
+ const int* KV_max,
28
+ float* dst,
29
+ sycl::float2* dst_meta,
30
+ const float scale,
31
+ const float max_bias,
32
+ const float m0,
33
+ const float m1,
34
+ const uint32_t n_head_log2,
35
+ const float logit_softcap,
36
+ const int32_t ne00,
37
+ const sycl::uint3 ne01,
38
+ const int32_t ne02,
39
+ const int32_t ne03,
40
+ const int32_t nb01,
41
+ const int32_t nb02,
42
+ const int32_t nb03,
43
+ const int32_t ne10,
44
+ const int32_t ne11,
45
+ const int32_t ne12,
46
+ const int32_t ne13,
47
+ const int32_t nb11,
48
+ const int32_t nb12,
49
+ const int64_t nb13,
50
+ const int32_t nb21,
51
+ const int32_t nb22,
52
+ const int64_t nb23,
53
+ const int32_t ne31,
54
+ const int32_t ne32,
55
+ const int32_t ne33,
56
+ const int32_t nb31,
57
+ const int32_t nb32,
58
+ const int64_t nb33);
59
+
// Function-pointer type shared by all vec_dot_fattn_vec_KQ_* helpers.
//   K_c  : raw bytes of the K data (reinterpreted per K type by the callee)
//   Q_v  : Q values in floating-point form (read only by the F16 variant)
//   Q_q8 : Q values packed as 8-bit integers, four per int (read only by the quantized variants)
//   Q_ds : per-slot float2 companion data for Q_q8 (read only by the quantized variants)
using vec_dot_KQ_t = float (*)(const char * __restrict__ K_c,
                               const void * __restrict__ Q_v,
                               const int * __restrict__  Q_q8,
                               const void * __restrict__ Q_ds);
62
+
63
+ template <int D, int nthreads>
64
+ static __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c,
65
+ const void * __restrict__ Q_v,
66
+ const int * __restrict__ Q_q8,
67
+ const void * __restrict__ Q_ds_v) {
68
+ const sycl::half2 * K_h2 = (const sycl::half2 *) K_c;
69
+ GGML_UNUSED(Q_q8);
70
+ GGML_UNUSED(Q_ds_v);
71
+
72
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
73
+ constexpr int cpy_ne = cpy_nb / 4;
74
+
75
+ float sum = 0.0f;
76
+
77
+ #pragma unroll
78
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
79
+ sycl::half2 tmp[cpy_ne];
80
+ ggml_sycl_memcpy_1<sizeof(tmp)>(
81
+ tmp,
82
+ K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne);
83
+ #pragma unroll
84
+ for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
85
+ #ifdef GGML_SYCL_F16
86
+ ggml_sycl_mad(sum, tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
87
+ #else
88
+ ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
89
+ #endif // GGML_SYCL_F16
90
+ }
91
+ }
92
+
93
+ return sum;
94
+ }
95
+
96
+ template <int D, int nthreads, int warp_size>
97
+ static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c,
98
+ const void * __restrict__ Q_v,
99
+ const int * __restrict__ Q_q8,
100
+ const void * __restrict__ Q_ds_v) {
101
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
102
+
103
+ const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
104
+ GGML_UNUSED(Q_v);
105
+
106
+ float sum = 0.0f;
107
+
108
+ #pragma unroll
109
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
110
+ const int k_KQ =
111
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
112
+
113
+ const int ib = k_KQ / QI8_1;
114
+ const int iqs4 = k_KQ % QI4_0;
115
+ const int shift = k_KQ & (QI8_1/2);
116
+
117
+ int v;
118
+ ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
119
+ v = (v >> shift) & 0x0F0F0F0F;
120
+ const int u = Q_q8[k_KQ_0/nthreads];
121
+
122
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
123
+
124
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
125
+ sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y());
126
+ }
127
+
128
+ return sum;
129
+ }
130
+
131
+ template <int D, int nthreads , int warp_size>
132
+ static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c,
133
+ const void * __restrict__ Q_v,
134
+ const int * __restrict__ Q_q8,
135
+ const void * __restrict__ Q_ds_v) {
136
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
137
+ const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
138
+ GGML_UNUSED(Q_v);
139
+
140
+ float sum = 0.0f;
141
+
142
+ #pragma unroll
143
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
144
+ const int k_KQ =
145
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
146
+
147
+ const int ib = k_KQ / QI8_1;
148
+ const int iqs4 = k_KQ % QI4_1;
149
+ const int shift = k_KQ & (QI8_1/2);
150
+
151
+ int v;
152
+ ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
153
+ v = (v >> shift) & 0x0F0F0F0F;
154
+ const int u = Q_q8[k_KQ_0/nthreads];
155
+
156
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
157
+
158
+ const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
159
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
160
+
161
+ sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
162
+ }
163
+
164
+ return sum;
165
+ }
166
+
167
+ template <int D, int nthreads, int warp_size>
168
+ static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c,
169
+ const void * __restrict__ Q_v,
170
+ const int * __restrict__ Q_q8,
171
+ const void * __restrict__ Q_ds_v) {
172
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
173
+ const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
174
+ GGML_UNUSED(Q_v);
175
+
176
+ float sum = 0.0f;
177
+
178
+ #pragma unroll
179
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
180
+ const int k_KQ =
181
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
182
+
183
+ const int ib = k_KQ / QI8_1;
184
+ const int iqs4 = k_KQ % QI5_0;
185
+ const int iqs8 = k_KQ % QI8_1;
186
+ const int shift = k_KQ & (QI8_1/2);
187
+
188
+ int v;
189
+ ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
190
+ v = (v >> shift) & 0x0F0F0F0F;
191
+
192
+ {
193
+ int vh;
194
+ ggml_sycl_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
195
+ vh >>= iqs8 * QI5_0;
196
+
197
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
198
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
199
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
200
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
201
+ }
202
+
203
+ const int u = Q_q8[k_KQ_0/nthreads];
204
+
205
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
206
+
207
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
208
+
209
+ sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y());
210
+ }
211
+
212
+ return sum;
213
+ }
214
+
215
+ template <int D, int nthreads, int warp_size>
216
+ static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c,
217
+ const void * __restrict__ Q_v,
218
+ const int * __restrict__ Q_q8,
219
+ const void * __restrict__ Q_ds_v) {
220
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
221
+ const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
222
+ GGML_UNUSED(Q_v);
223
+
224
+ float sum = 0.0f;
225
+
226
+ #pragma unroll
227
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
228
+ const int k_KQ =
229
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
230
+
231
+ const int ib = k_KQ / QI8_1;
232
+ const int iqs4 = k_KQ % QI5_1;
233
+ const int iqs8 = k_KQ % QI8_1;
234
+ const int shift = k_KQ & (QI8_1/2);
235
+
236
+ int v;
237
+ ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
238
+ v = (v >> shift) & 0x0F0F0F0F;
239
+
240
+ {
241
+ int vh;
242
+ ggml_sycl_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
243
+ vh >>= iqs8 * QI5_0;
244
+
245
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
246
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
247
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
248
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
249
+ }
250
+
251
+ const int u = Q_q8[k_KQ_0/nthreads];
252
+
253
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
254
+
255
+ const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
256
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
257
+
258
+ sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
259
+ }
260
+
261
+ return sum;
262
+ }
263
+
264
+ template <int D, int nthreads, int warp_size>
265
+ static __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c,
266
+ const void * __restrict__ Q_v,
267
+ const int * __restrict__ Q_q8,
268
+ const void * __restrict__ Q_ds_v) {
269
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
270
+ const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
271
+ GGML_UNUSED(Q_v);
272
+
273
+ float sum = 0.0f;
274
+
275
+ #pragma unroll
276
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
277
+ const int k_KQ =
278
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
279
+
280
+ const int ib = k_KQ / QI8_0;
281
+ const int iqs = k_KQ % QI8_0;
282
+
283
+ int v;
284
+ ggml_sycl_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
285
+
286
+ const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v;
287
+ const float Q_d = Q_ds[k_KQ_0 / nthreads].x();
288
+
289
+ sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
290
+ }
291
+
292
+ return sum;
293
+ }
294
+
295
+ template <typename Tds, int ni, int warp_size>
296
+ static __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x,
297
+ const float scale,
298
+ int * __restrict__ yq32,
299
+ void * __restrict__ yds) {
300
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
301
+
302
+ float vals[sizeof(int)] = { 0.0f };
303
+ #pragma unroll
304
+ for (int l = 0; l < int(sizeof(int)); ++l) {
305
+ vals[l] =
306
+ (ni == warp_size || item_ct1.get_local_id(2) < ni) ? scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f;
307
+ }
308
+
309
+ float amax = sycl::fabs(vals[0]);
310
+ float sum = vals[0];
311
+ #pragma unroll
312
+ for (int l = 1; l < int(sizeof(int)); ++l) {
313
+ amax = sycl::fmax(amax, sycl::fabs(vals[l]));
314
+ sum += vals[l];
315
+ }
316
+ #pragma unroll
317
+ for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
318
+ amax = sycl::fmax(
319
+ amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask));
320
+ sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask);
321
+ }
322
+
323
+ const float d = amax / 127;
324
+ int q32 = 0;
325
+ int8_t * q8 = (int8_t *) &q32;
326
+
327
+ if (d != 0.0f) {
328
+ #pragma unroll
329
+ for (int l = 0; l < int(sizeof(int)); ++l) {
330
+ q8[l] = sycl::round(vals[l] / d);
331
+ }
332
+ }
333
+
334
+ yq32[item_ct1.get_local_id(2)] = q32;
335
+ if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) {
336
+ if (std::is_same<Tds, sycl::half2>::value) {
337
+ ((sycl::half2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_half2(d, sum);
338
+ } else {
339
+ ((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum);
340
+ }
341
+ }
342
+ }
343
+
// Function-pointer type shared by all dequantize_V_* helpers:
// (source V data, destination buffer, index of the first element to dequantize).
using dequantize_V_t = void (*)(const void *, void *, const int64_t);
345
+
346
+ template <typename T, int ne>
347
+ static __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
348
+ if constexpr (std::is_same_v<T, sycl::half>) {
349
+ ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(dst, (const sycl::half *) vx + i0);
350
+ } else if constexpr (std::is_same_v<T, float>) {
351
+ static_assert(ne % 2 == 0, "bad ne");
352
+ sycl::half2 tmp[ne / 2];
353
+ ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(tmp, (const sycl::half *) vx + i0);
354
+ sycl::float2 * dst_f2 = (sycl::float2 *) dst;
355
+ #pragma unroll
356
+ for (int l = 0; l < ne/2; ++l) {
357
+ dst_f2[l] = tmp[l].template convert<float, sycl::rounding_mode::automatic>();
358
+ }
359
+ } else {
360
+ static_assert(std::is_same_v<T, void>, "unsupported type");
361
+ }
362
+ }
363
+
364
+ template <typename T, int ne>
365
+ static __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
366
+ const block_q4_0 * x = (const block_q4_0 *) vx;
367
+
368
+ const int64_t ib = i0 / QK4_0;
369
+ const int iqs = i0 % (QK4_0/2);
370
+ const int shift = (i0 % QK4_0) / (QK4_0/2);
371
+
372
+ int q;
373
+ static_assert(ne == 2 || ne == 4, "bad ne");
374
+ ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
375
+ q >>= 4*shift;
376
+ q &= 0x0F0F0F0F;
377
+ q = dpct::vectorized_binary<sycl::char4>(q, 0x08080808, dpct::sub_sat());
378
+
379
+ const int8_t * q8 = (const int8_t *) &q;
380
+
381
+ #ifdef GGML_SYCL_F16
382
+ if constexpr (std::is_same_v<T, sycl::half>) {
383
+ const sycl::half2 d = sycl::half2(x[ib].d);
384
+
385
+ #pragma unroll
386
+ for (int l0 = 0; l0 < ne; l0 += 2) {
387
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
388
+ }
389
+ } else
390
+ #endif // GGML_SYCL_F16
391
+ if constexpr (std::is_same_v<T, float>) {
392
+ const float d = x[ib].d;
393
+
394
+ #pragma unroll
395
+ for (int l = 0; l < ne; ++l) {
396
+ ((float *) dst)[l] = d * q8[l];
397
+ }
398
+ } else {
399
+ static_assert(std::is_same_v<T, void>, "bad type");
400
+ }
401
+ }
402
+
403
+ template <typename T, int ne>
404
+ static __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
405
+ const block_q4_1 * x = (const block_q4_1 *) vx;
406
+
407
+ const int64_t ib = i0 / QK4_1;
408
+ const int iqs = i0 % (QK4_1/2);
409
+ const int shift = (i0 % QK4_1) / (QK4_1/2);
410
+
411
+ int q;
412
+ static_assert(ne == 2 || ne == 4, "bad ne");
413
+ ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
414
+ q >>= 4*shift;
415
+ q &= 0x0F0F0F0F;
416
+
417
+ const int8_t * q8 = (const int8_t *) &q;
418
+
419
+ #ifdef GGML_SYCL_F16
420
+ if constexpr (std::is_same_v<T, sycl::half>) {
421
+ const sycl::half2 dm = x[ib].dm;
422
+ const sycl::half2 d = sycl::half2(dm[0]);
423
+ const sycl::half2 m = sycl::half2(dm[1]);
424
+
425
+ #pragma unroll
426
+ for (int l0 = 0; l0 < ne; l0 += 2) {
427
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
428
+ }
429
+ } else
430
+ #endif // GGML_SYCL_F16
431
+ if constexpr (std::is_same_v<T, float>) {
432
+ const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
433
+
434
+ #pragma unroll
435
+ for (int l = 0; l < ne; ++l) {
436
+ ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
437
+ }
438
+ } else {
439
+ static_assert(std::is_same_v<T, void>, "bad type");
440
+ }
441
+ }
442
+
443
+ template <typename T, int ne>
444
+ static __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
445
+ const block_q5_0 * x = (const block_q5_0 *) vx;
446
+
447
+ const int64_t ib = i0 / QK5_0;
448
+ const int idq = i0 % QK5_0;
449
+ const int iqs = i0 % (QK5_0/2);
450
+ const int shift = (i0 % QK5_0) / (QK5_0/2);
451
+
452
+ int q;
453
+ static_assert(ne == 2 || ne == 4, "bad ne");
454
+ ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
455
+ q >>= 4*shift;
456
+ q &= 0x0F0F0F0F;
457
+
458
+ {
459
+ int qh;
460
+ ggml_sycl_memcpy_1<ne, 2>(&qh, x[ib].qh);
461
+ #pragma unroll
462
+ for (int l = 0; l < ne; ++l) {
463
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
464
+ }
465
+ }
466
+
467
+ q = dpct::vectorized_binary<sycl::char4>(q, 0x10101010, dpct::sub_sat());
468
+
469
+ const int8_t * q8 = (const int8_t *) &q;
470
+
471
+ #ifdef GGML_SYCL_F16
472
+ if constexpr (std::is_same_v<T, sycl::half>) {
473
+ const sycl::half2 d = sycl::half2(x[ib].d);
474
+
475
+ #pragma unroll
476
+ for (int l0 = 0; l0 < ne; l0 += 2) {
477
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
478
+ }
479
+ } else
480
+ #endif // GGML_SYCL_F16
481
+ if constexpr (std::is_same_v<T, float>) {
482
+ const float d = x[ib].d;
483
+
484
+ #pragma unroll
485
+ for (int l = 0; l < ne; ++l) {
486
+ ((float *) dst)[l] = d * q8[l];
487
+ }
488
+ } else {
489
+ static_assert(std::is_same_v<T, void>, "bad type");
490
+ }
491
+ }
492
+
493
+ template <typename T, int ne>
494
+ static __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
495
+ const block_q5_1 * x = (const block_q5_1 *) vx;
496
+
497
+ const int64_t ib = i0 / QK5_1;
498
+ const int idq = i0 % QK5_1;
499
+ const int iqs = i0 % (QK5_1/2);
500
+ const int shift = (i0 % QK5_1) / (QK5_1/2);
501
+
502
+ int q;
503
+ static_assert(ne == 2 || ne == 4, "bad ne");
504
+ ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
505
+ q >>= 4*shift;
506
+ q &= 0x0F0F0F0F;
507
+
508
+ {
509
+ int qh;
510
+ ggml_sycl_memcpy_1<ne>(&qh, x[ib].qh);
511
+ #pragma unroll
512
+ for (int l = 0; l < ne; ++l) {
513
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
514
+ }
515
+ }
516
+
517
+ const int8_t * q8 = (const int8_t *) &q;
518
+
519
+ #ifdef GGML_SYCL_F16
520
+ if constexpr (std::is_same_v<T, sycl::half>) {
521
+ const sycl::half2 dm = x[ib].dm;
522
+ const sycl::half2 d = sycl::half2(dm[0]);
523
+ const sycl::half2 m = sycl::half2(dm[1]);
524
+
525
+ #pragma unroll
526
+ for (int l0 = 0; l0 < ne; l0 += 2) {
527
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
528
+ }
529
+ } else
530
+ #endif // GGML_SYCL_F16
531
+ if constexpr (std::is_same_v<T, float>) {
532
+ const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
533
+
534
+ #pragma unroll
535
+ for (int l = 0; l < ne; ++l) {
536
+ ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
537
+ }
538
+ } else {
539
+ static_assert(std::is_same_v<T, void>, "bad type");
540
+ }
541
+ }
542
+
543
+ template <typename T, int ne>
544
+ static __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
545
+ const block_q8_0 * x = (const block_q8_0 *) vx;
546
+
547
+ const int64_t ib = i0 / QK8_0;
548
+ const int iqs = i0 % QK8_0;
549
+
550
+ static_assert(ne % 2 == 0, "bad ne");
551
+ int8_t qs[ne];
552
+ ggml_sycl_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
553
+
554
+ #ifdef GGML_SYCL_F16
555
+ if constexpr (std::is_same<T, sycl::half>::value) {
556
+ const sycl::half2 d = sycl::half2(x[ib].d);
557
+
558
+ #pragma unroll
559
+ for (int l0 = 0; l0 < ne; l0 += 2) {
560
+ ((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
561
+ }
562
+ } else
563
+ #endif // GGML_SYCL_F16
564
+ if constexpr (std::is_same<T, float>::value) {
565
+ const float d = x[ib].d;
566
+
567
+ #pragma unroll
568
+ for (int l = 0; l < ne; ++l) {
569
+ ((float *) dst)[l] = d * qs[l];
570
+ }
571
+ } else {
572
+ static_assert(std::is_same_v<T, void>, "unsupported type");
573
+ }
574
+ }
575
+
576
+ template <int type_K, int D, int nthreads, int warp_size>
577
+ constexpr vec_dot_KQ_t get_vec_dot_KQ() {
578
+ if constexpr (type_K == GGML_TYPE_F16) {
579
+ return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
580
+ } else if constexpr (type_K == GGML_TYPE_Q4_0) {
581
+ return vec_dot_fattn_vec_KQ_q4_0<D, nthreads, warp_size>;
582
+ } else if constexpr (type_K == GGML_TYPE_Q4_1) {
583
+ return vec_dot_fattn_vec_KQ_q4_1<D, nthreads, warp_size>;
584
+ } else if constexpr (type_K == GGML_TYPE_Q5_0) {
585
+ return vec_dot_fattn_vec_KQ_q5_0<D, nthreads, warp_size>;
586
+ } else if constexpr (type_K == GGML_TYPE_Q5_1) {
587
+ return vec_dot_fattn_vec_KQ_q5_1<D, nthreads, warp_size>;
588
+ } else if constexpr (type_K == GGML_TYPE_Q8_0) {
589
+ return vec_dot_fattn_vec_KQ_q8_0<D, nthreads, warp_size>;
590
+ } else {
591
+ static_assert(type_K == -1, "bad type");
592
+ return nullptr;
593
+ }
594
+ }
595
+
596
+ template <int type_V, typename T, int ne>
597
+ constexpr dequantize_V_t get_dequantize_V() {
598
+ if constexpr (type_V == GGML_TYPE_F16) {
599
+ return dequantize_V_f16<T, ne>;
600
+ } else if constexpr (type_V == GGML_TYPE_Q4_0) {
601
+ return dequantize_V_q4_0<T, ne>;
602
+ } else if constexpr (type_V == GGML_TYPE_Q4_1) {
603
+ return dequantize_V_q4_1<T, ne>;
604
+ } else if constexpr (type_V == GGML_TYPE_Q5_0) {
605
+ return dequantize_V_q5_0<T, ne>;
606
+ } else if constexpr (type_V == GGML_TYPE_Q5_1) {
607
+ return dequantize_V_q5_1<T, ne>;
608
+ } else if constexpr (type_V == GGML_TYPE_Q8_0) {
609
+ return dequantize_V_q8_0<T, ne>;
610
+ } else {
611
+ static_assert(type_V == -1, "bad type");
612
+ return nullptr;
613
+ }
614
+ }
615
+
616
// Scans the attention mask backwards, one FATTN_KQ_STRIDE-wide tile at a time, and stops
// at the first tile (counted from the end) in which at least one mask value is not
// infinite. The resulting upper bound is stored in KV_max[sequence*ne31 + jt].
//
// Template parameters:
//   ncols1    - number of consecutive mask rows inspected by one work-group.
//   warp_size - width passed to warp_reduce_all; also used to index buf_iw.
// Parameters:
//   mask   - mask values viewed as half2 pairs (all offsets below are in half2 units).
//   KV_max - output array; only the work-item with local id 0 writes to it.
//   ne30   - number of FATTN_KQ_STRIDE-wide tiles to scan.
//   s31    - stride between consecutive mask rows, in half2 elements.
//   s33    - stride between sequences, in half2 elements.
//   buf_iw - work-group scratch holding at least warp_size ints.
//            NOTE(review): presumably local memory provided by the launcher - confirm.
+ template <int ncols1, int warp_size>
617
+ static void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask,
618
+ int * __restrict__ KV_max,
619
+ const int ne30,
620
+ const int s31,
621
+ const int s33,
622
+ int * buf_iw) {
623
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
624
// Number of work-groups along dimension 2; used below as the row pitch of KV_max.
+ const int ne31 = item_ct1.get_group_range(2);
625
+ const int tid = item_ct1.get_local_id(2);
626
+ const int sequence = item_ct1.get_group(1);
627
+ const int jt = item_ct1.get_group(2);
628
+
629
// Advance to the first mask row handled by this work-group.
+ mask += sequence*s33 + jt*ncols1*s31;
630
+
631
// Pre-fill the scratch slots with 1. Slots that are never overwritten inside the loop
// keep this value (presumably neutral for the "all" reduction - confirm warp_reduce_all).
+ if (tid < warp_size) {
632
+ buf_iw[tid] = 1;
633
+ }
634
+ item_ct1.barrier(sycl::access::fence_space::local_space);
635
+
636
// Start at the last tile and walk towards index 0.
+ int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
637
+ for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
638
+ int all_inf = 1;
639
+
640
+ #pragma unroll
641
+ for (int j = 0; j < ncols1; ++j) {
642
// Each work-item loads one half2 (two mask values) per row and tests both for infinity.
+ const sycl::float2 tmp =
643
+ mask[j * s31 + KV_max_sj / 2 + tid].template convert<float, sycl::rounding_mode::automatic>();
644
+ all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) (tmp.y())));
645
+ }
646
+
647
// Two-stage reduction: first reduce via warp_reduce_all, then exchange one flag per
// group of warp_size work-items through buf_iw and reduce those flags again.
+ all_inf = warp_reduce_all<warp_size>(all_inf);
648
+ if (tid % warp_size == 0) {
649
+ buf_iw[tid / warp_size] = all_inf;
650
+ }
651
+ item_ct1.barrier(sycl::access::fence_space::local_space);
652
+ all_inf = buf_iw[tid % warp_size];
653
// Second barrier: every work-item must have read buf_iw before the next iteration
// overwrites it.
+ item_ct1.barrier(sycl::access::fence_space::local_space);
654
+ all_inf = warp_reduce_all<warp_size>(all_inf);
655
+
656
+ if (!all_inf) {
657
+ break;
658
+ }
659
+ }
660
+
661
+ // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
662
+ // If the break was triggered it's the lower edge of the tile with the first non-masked values.
663
+ // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
664
+ KV_max_sj += FATTN_KQ_STRIDE;
665
+
666
// Only one work-item per work-group publishes the result.
+ if (item_ct1.get_local_id(2) != 0) {
667
+ return;
668
+ }
669
+
670
+ KV_max[sequence*ne31 + jt] = KV_max_sj;
671
+ }
672
+
673
// Fixup pass for the stream-k work split.
//
// The total number of k-iterations is divided evenly across get_group_range(2) blocks, so
// a block may start in the middle of a tile. A block that started mid-tile and reached the
// end of that tile holds a partial result in dst; this function merges it with the partial
// results stored in dst_fixup by the preceding blocks that worked on the same tile, then
// divides by the combined row sum.
//
// Work decomposition: group(2) = block index bidx0, group(1) = j, group(0) = c,
// local id(2) = tid (one output value per work-item).
//
// Parameters:
//   dst        - output buffer; read, rescaled and written back in place.
//   dst_fixup  - float2 meta entries (.x() used as max, .y() used as row sum) followed by
//                the float partial results (see dst_fixup_data below).
//   ne01       - bound checked against jt*ncols1 + j; rounded up by ncols1 to get iter_j.
//   ne02, ne12 - their quotient ne02 / ne12 is used as the GQA ratio.
//   ne03       - factor of the total work count.
//   ne11       - rounded up by nbatch_fa to get the k-iterations per tile (iter_k).
//   nbatch_fa  - batch size along ne11 for one k-iteration.
+ template <int D, int ncols1, int ncols2> // D == head size
674
+
675
+ static void flash_attn_stream_k_fixup(float * __restrict__ dst,
676
+ const sycl::float2 * __restrict__ dst_fixup,
677
+ const int ne01,
678
+ const int ne02,
679
+ const int ne03,
680
+ const int ne11,
681
+ const int ne12,
682
+ const int nbatch_fa) {
683
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
684
+ constexpr int ncols = ncols1 * ncols2;
685
+
686
+ const int bidx0 = item_ct1.get_group(2);
687
+ const int j = item_ct1.get_group(1);
688
+ const int c = item_ct1.get_group(0);
689
+ const int jc = j*ncols2 + c;
690
+ const int tid = item_ct1.get_local_id(2);
691
+
692
// The float payload starts after 2 * get_group_range(2) * ncols float2 meta entries.
+ const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols);
693
+
694
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
695
+
696
// Ceil-divisions giving the iteration counts along k, along columns and along GQA groups.
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
697
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
698
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
699
+
700
// [kbc0, kbc0_stop) is the range of k-iterations assigned to block bidx0.
+ const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
701
+ const int kbc0_stop =
702
+ int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
703
+
704
// Nothing to fix up if this block had no work, started exactly on a tile boundary,
// or stayed inside its starting tile without reaching that tile's end.
+ const bool did_not_have_any_data = kbc0 == kbc0_stop;
705
+ const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
706
+ const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
707
+ if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
708
+ return;
709
+ }
710
+
711
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
712
+ const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
713
+ const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
714
+ const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
715
+ const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
716
+
717
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
718
+
719
// Skip columns / heads that fall into the padding of the last, partially filled tile.
+ if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
720
+ return;
721
+ }
722
+
723
// Point dst at the single output value owned by this work-item.
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
724
+
725
+ // Load the partial result that needs a fixup:
726
+ float dst_val = 0.0f;
727
+ float max_val = 0.0f;
728
+ float rowsum = 0.0f;
729
+ {
730
+ dst_val = *dst;
731
+
732
// Meta entry of this block: taken from the first get_group_range(2) * ncols entries.
+ const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc];
733
+ max_val = tmp.x();
734
+ rowsum = tmp.y();
735
+ }
736
+
737
+ // Iterate over previous blocks and compute the combined results.
738
+ // All SYCL blocks that get here must have a previous block that needs a fixup.
739
+ int bidx = bidx0 - 1;
740
+ int kbc_stop = kbc0;
741
+ while(true) {
742
+ const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
743
+ if (kbc == kbc_stop) { // Did not have any data.
744
+ bidx--;
745
+ kbc_stop = kbc;
746
+ continue;
747
+ }
748
+
749
+ const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
750
+
751
// Meta entry of the previous block: taken from the second get_group_range(2) * ncols entries.
+ const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc];
752
+
753
+ // Scale the current and new value accumulators depending on the max. values.
754
+ const float max_val_new = sycl::fmax(max_val, tmp.x());
755
+
756
+ const float diff_val = max_val - max_val_new;
757
+ const float diff_add = tmp.x() - max_val_new;
758
+
759
// Differences below SOFTMAX_FTZ_THRESHOLD yield a scale of exactly 0 instead of exp().
+ const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f;
760
+ const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_add) : 0.0f;
761
+
762
+ dst_val = scale_val*dst_val + scale_add*dst_add;
763
+ rowsum = scale_val * rowsum + scale_add * tmp.y();
764
+
765
+ max_val = max_val_new;
766
+
767
+ // If this block started in a previous tile we are done and don't need to combine additional partial results.
768
+ if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
769
+ break;
770
+ }
771
+ bidx--;
772
+ kbc_stop = kbc;
773
+ }
774
+
775
+ // Write back final result:
776
+ *dst = dst_val / rowsum;
777
+ }
778
+
779
+ template <int D> // D == head size
780
+
781
+ static void flash_attn_combine_results(const float * __restrict__ VKQ_parts,
782
+ const sycl::float2 * __restrict__ VKQ_meta,
783
+ float * __restrict__ dst,
784
+ const int parallel_blocks,
785
+ uint8_t * dpct_local) {
786
+ // Dimension 0: threadIdx.x
787
+ // Dimension 1: blockIdx.x
788
+ // Dimension 2: blockIdx.y
789
+ // Dimension 3: blockIdx.z
790
+ // Memory layout is permuted with [0, 2, 1, 3]
791
+
792
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
793
+ const int ne01 = item_ct1.get_group_range(2);
794
+ const int ne02 = item_ct1.get_group_range(1);
795
+
796
+ const int col = item_ct1.get_group(2);
797
+ const int head = item_ct1.get_group(1);
798
+ const int sequence = item_ct1.get_group(0);
799
+
800
+ const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
801
+
802
+ VKQ_parts += j_dst_unrolled * parallel_blocks*D;
803
+ VKQ_meta += j_dst_unrolled * parallel_blocks;
804
+ dst += j_dst_unrolled * D;
805
+
806
+ const int tid = item_ct1.get_local_id(2);
807
+ __builtin_assume(tid < D);
808
+
809
+ auto meta = (sycl::float2 *) dpct_local;
810
+ for (int i = tid; i < 2*parallel_blocks; i += D) {
811
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
812
+ }
813
+
814
+ item_ct1.barrier(sycl::access::fence_space::local_space);
815
+
816
+ float kqmax = meta[0].x();
817
+ for (int l = 1; l < parallel_blocks; ++l) {
818
+ kqmax = sycl::max(kqmax, meta[l].x());
819
+ }
820
+
821
+ float VKQ_numerator = 0.0f;
822
+ float VKQ_denominator = 0.0f;
823
+ for (int l = 0; l < parallel_blocks; ++l) {
824
+ const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax);
825
+
826
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
827
+ VKQ_denominator += KQ_max_scale * meta[l].y();
828
+ }
829
+
830
+ dst[tid] = VKQ_numerator / VKQ_denominator;
831
+ }
832
+
833
+ template <fattn_kernel_t fattn_kernel, int warp_size>
834
+ static void lauch_kernel(
835
+ dpct::dim3 group_range,
836
+ dpct::dim3 local_range,
837
+ queue_ptr q,
838
+ unsigned int local_mem_size,
839
+ const char* __restrict__ Q,
840
+ const char* __restrict__ K,
841
+ const char* __restrict__ V,
842
+ const char* __restrict__ mask,
843
+ const char* __restrict__ sinks,
844
+ const int* __restrict__ KV_max,
845
+ float* __restrict__ dst,
846
+ sycl::float2* __restrict__ dst_meta,
847
+ const float scale,
848
+ const float max_bias,
849
+ const float m0,
850
+ const float m1,
851
+ const uint32_t n_head_log2,
852
+ const float logit_softcap,
853
+ const int32_t ne00,
854
+ const sycl::uint3 ne01,
855
+ const int32_t ne02,
856
+ const int32_t ne03,
857
+ const int32_t nb01,
858
+ const int32_t nb02,
859
+ const int32_t nb03,
860
+ const int32_t ne10,
861
+ const int32_t ne11,
862
+ const int32_t ne12,
863
+ const int32_t ne13,
864
+ const int32_t nb11,
865
+ const int32_t nb12,
866
+ const int64_t nb13,
867
+ const int32_t nb21,
868
+ const int32_t nb22,
869
+ const int64_t nb23,
870
+ const int32_t ne31,
871
+ const int32_t ne32,
872
+ const int32_t ne33,
873
+ const int32_t nb31,
874
+ const int32_t nb32,
875
+ const int64_t nb33) {
876
+ GGML_UNUSED(local_mem_size);
877
+ q->submit([&](sycl::handler &cgh) {
878
+ cgh.parallel_for(
879
+ sycl::nd_range<3>(
880
+ static_cast<sycl::range<3>>(group_range * local_range),
881
+ static_cast<sycl::range<3>>(local_range)),
882
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
883
+ GGML_UNUSED(item_ct1);
884
+ fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
885
+ max_bias, m0, m1, n_head_log2, logit_softcap, ne00,
886
+ ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11,
887
+ ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23,
888
+ ne31, ne32, ne33, nb31, nb32, nb33);
889
+ });
890
+ });
891
+ }
892
+
893
// Host-side launcher for a flash-attention kernel. Steps, in order:
//   1. Optionally convert K and/or V to f16 into pool buffers (need_f16_K / need_f16_V).
//   2. If a mask is present and the shape conditions hold, launch
//      flash_attn_mask_to_KV_max to fill KV_max.
//   3. Choose the launch grid: stream-k (fixed number of blocks) or a grid with
//      `parallel_blocks` blocks along y.
//   4. Launch fattn_kernel through lauch_kernel.
//   5. Launch the follow-up pass: flash_attn_stream_k_fixup (stream-k) or
//      flash_attn_combine_results (parallel_blocks > 1).
//
// Parameters:
//   ctx           - backend context providing the memory pool and the queue.
//   dst           - result tensor; src[0..4] are read as Q, K, V, mask and sinks.
//   nwarps        - second component of the work-group size (warp_size, nwarps, 1).
//   nbytes_shared - forwarded to lauch_kernel as local memory size.
//   nbatch_fa     - batch size along K->ne[1]; bounds parallel_blocks and is passed to the fixup.
//   need_f16_K/V  - request conversion of K / V to f16 when they have another type.
//   stream_k      - select the stream-k decomposition.
+ template <int DV, int ncols1, int ncols2, fattn_kernel_t fattn_kernel, int warp_size>
894
+ void launch_fattn(
895
+ ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared,
896
+ const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) {
897
+
898
+ constexpr int ncols = ncols1 * ncols2;
899
+
900
+ const ggml_tensor * Q = dst->src[0];
901
+ const ggml_tensor * K = dst->src[1];
902
+ const ggml_tensor * V = dst->src[2];
903
+
904
// True when V is a view of K itself, or a view of K's view source at the same offset.
+ const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
905
+
906
+ const ggml_tensor * mask = dst->src[3];
907
+ const ggml_tensor * sinks = dst->src[4];
908
+
909
+ ggml_tensor * KQV = dst;
910
+
911
+ GGML_ASSERT(Q->type == GGML_TYPE_F32);
912
+ GGML_ASSERT(KQV->type == GGML_TYPE_F32);
913
+
914
// The innermost dimension of Q, K and V must be densely packed.
+ GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
915
+ GGML_ASSERT(K->nb[0] == ggml_element_size(K));
916
+ GGML_ASSERT(V->nb[0] == ggml_element_size(V));
917
+
918
+ GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
919
+
920
+ ggml_sycl_pool & pool = ctx.pool();
921
+ dpct::queue_ptr main_stream = ctx.stream();
922
+ const int id = ggml_sycl_get_device();
923
+ const int nsm = ggml_sycl_info().devices[id].nsm;
924
+
925
// Pool-backed temporaries; each is only allocated below when actually needed.
+ ggml_sycl_pool_alloc<sycl::half> K_f16(pool);
926
+ ggml_sycl_pool_alloc<sycl::half> V_f16(pool);
927
+ ggml_sycl_pool_alloc<int> KV_max(pool);
928
+ ggml_sycl_pool_alloc<float> dst_tmp(pool);
929
+ ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool);
930
+
931
// Data pointers and byte strides handed to the kernel; replaced if K / V get converted.
+ const char * K_data = (const char *) K->data;
932
+ size_t nb11 = K->nb[1];
933
+ size_t nb12 = K->nb[2];
934
+ size_t nb13 = K->nb[3];
935
+
936
+ const char * V_data = (const char *) V->data;
937
+ size_t nb21 = V->nb[1];
938
+ size_t nb22 = V->nb[2];
939
+ size_t nb23 = V->nb[3];
940
+
941
+ if (need_f16_K && K->type != GGML_TYPE_F16) {
942
+ const size_t bs = ggml_blck_size(K->type);
943
+ const size_t ts = ggml_type_size(K->type);
944
+
945
+ K_f16.alloc(ggml_nelements(K));
946
+ if (ggml_is_contiguously_allocated(K)) {
947
+ to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst);
948
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
949
+
950
// Rescale byte strides from the source type (ts bytes per bs elements) to f16.
+ nb11 = nb11 * bs * sizeof(sycl::half) / ts;
951
+ nb12 = nb12 * bs * sizeof(sycl::half) / ts;
952
+ nb13 = nb13 * bs * sizeof(sycl::half) / ts;
953
+ } else {
954
+ GGML_ASSERT(K->nb[0] == ts);
955
+ to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type);
956
// Source strides are passed in units of ts (nb / ts).
+ const int64_t s01 = nb11 / ts;
957
+ const int64_t s02 = nb12 / ts;
958
+ const int64_t s03 = nb13 / ts;
959
+ to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
960
+
961
// Strides for the converted buffer are rebuilt from the tensor extents.
+ nb11 = K->ne[0] * sizeof(sycl::half);
962
+ nb12 = K->ne[1] * nb11;
963
+ nb13 = K->ne[2] * nb12;
964
+ }
965
+ K_data = (char *) K_f16.ptr;
966
+ }
967
+
968
+ if (need_f16_V && V->type != GGML_TYPE_F16) {
969
+ if (V_is_K_view) {
970
// Reuse whatever K_data and the K strides currently refer to instead of converting V.
+ V_data = K_data;
971
+ nb21 = nb11;
972
+ nb22 = nb12;
973
+ nb23 = nb13;
974
+ } else {
975
+ const size_t bs = ggml_blck_size(V->type);
976
+ const size_t ts = ggml_type_size(V->type);
977
+
978
+ V_f16.alloc(ggml_nelements(V));
979
+ if (ggml_is_contiguously_allocated(V)) {
980
+ to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst);
981
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
982
// NOTE(review): this assignment is repeated after the if/else below, so it is redundant here.
+ V_data = (char *) V_f16.ptr;
983
+
984
+ nb21 = nb21 * bs * sizeof(sycl::half) / ts;
985
+ nb22 = nb22 * bs * sizeof(sycl::half) / ts;
986
+ nb23 = nb23 * bs * sizeof(sycl::half) / ts;
987
+ } else {
988
+ GGML_ASSERT(V->nb[0] == ts);
989
+ to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type);
990
+ const int64_t s01 = nb21 / ts;
991
+ const int64_t s02 = nb22 / ts;
992
+ const int64_t s03 = nb23 / ts;
993
+ to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
994
+
995
+ nb21 = V->ne[0] * sizeof(sycl::half);
996
+ nb22 = V->ne[1] * nb21;
997
+ nb23 = V->ne[2] * nb22;
998
+ }
999
+ V_data = (char *) V_f16.ptr;
1000
+ }
1001
+ }
1002
+
1003
// Tile counts: along Q rows, along GQA groups, and in total over K heads and sequences.
+ const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
1004
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
1005
+ const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
1006
+ const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
1007
+
1008
+ // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
1009
+ // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
1010
+ // multiple sequences of possibly different lengths.
1011
+ if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
1012
// Mask strides converted from bytes to half2 elements.
+ const int s31 = mask->nb[1] / sizeof(sycl::half2);
1013
+ const int s33 = mask->nb[3] / sizeof(sycl::half2);
1014
+
1015
// One work-group per (column tile, sequence); each has FATTN_KQ_STRIDE / 2 work-items.
+ const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
1016
+ const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1);
1017
+
1018
+ const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
1019
+ const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
1020
+
1021
+ KV_max.alloc(ne_KV_max);
1022
+ {
1023
+ dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 });
1024
+
1025
+ main_stream->submit([&](sycl::handler & cgh) {
1026
// Scratch of warp_size ints handed to the kernel as buf_iw.
+ sycl::local_accessor<int, 1> buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh);
1027
+
1028
+ auto mask_data_ct0 = (const sycl::half2 *) mask->data;
1029
+ auto KV_max_ptr_ct1 = KV_max.ptr;
1030
+
1031
+ cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
1032
+ [=](sycl::nd_item<3> item_ct1) {
1033
+ GGML_UNUSED(item_ct1);
1034
+ flash_attn_mask_to_KV_max<ncols1, warp_size>(
1035
+ mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
1036
+ buf_iw_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
1037
+ });
1038
+ });
1039
+ }
1040
// NOTE(review): SYCL_CHECK is given the constant 0 here; presumably a placeholder - confirm.
+ SYCL_CHECK(0);
1041
+ }
1042
+
1043
+ const dpct::dim3 block_dim(warp_size, nwarps, 1);
1044
+
1045
+ // Max. number of active blocks limited by occupancy.
1046
+ int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu;
1047
+ int parallel_blocks = max_blocks_per_sm;
1048
+ dpct::dim3 blocks_num;
1049
+ if (stream_k) {
1050
+ // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
1051
+ const int max_blocks = max_blocks_per_sm*nsm;
1052
+ const int nblocks_stream_k = max_blocks;
1053
// NOTE(review): use_stream_k is hard-coded to true, so the ntiles_total alternative below is never selected.
+ const bool use_stream_k = true;
1054
+
1055
+ blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
1056
+ blocks_num.y = 1;
1057
+ blocks_num.z = 1;
1058
+
1059
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
1060
// Per block and column: 2 float2 meta entries plus DV floats (= DV/2 float2) of partial output.
+ dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
1061
+ }
1062
+ } else {
1063
+ const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
1064
+
1065
+ // parallel_blocks must not be larger than what the tensor size allows:
1066
+ parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
1067
+ // todo fix the hard code change
1068
+ // parallel_blocks = ntiles_KQ;
1069
+
1070
+ // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
1071
+ // Test whether parallel_blocks can be set to a higher value for better efficiency.
1072
+ const int blocks_per_wave = nsm * max_blocks_per_sm;
1073
+ int nwaves_best = 0;
1074
+ int efficiency_percent_best = 0;
1075
+ for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
1076
+ const int nblocks_total = ntiles_total * parallel_blocks_test;
1077
+ const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
1078
// Percentage of the launched wave slots that are occupied by real blocks.
+ const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
1079
+
1080
+ // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
1081
+ if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
1082
+ break;
1083
+ }
1084
+
1085
+ if (efficiency_percent > efficiency_percent_best) {
1086
+ nwaves_best = nwaves;
1087
+ efficiency_percent_best = efficiency_percent;
1088
+ parallel_blocks = parallel_blocks_test;
1089
+ }
1090
+ }
1091
+
1092
+ blocks_num.x = ntiles_x;
1093
+ blocks_num.y = parallel_blocks;
1094
+ blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
1095
+
1096
// With more than one block per tile, the kernel writes partial results that are combined afterwards.
+ if (parallel_blocks > 1) {
1097
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
1098
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
1099
+ }
1100
+ }
1101
+
1102
+ float scale = 1.0f;
1103
+ float max_bias = 0.0f;
1104
+ float logit_softcap = 0.0f;
1105
+
1106
// op_params holds scale, max_bias and logit_softcap as floats at indices 0, 1 and 2.
+ memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
1107
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
1108
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
1109
+
1110
// When a soft cap is set, the scale passed to the kernel is divided by it.
+ if (logit_softcap != 0.0f) {
1111
+ scale /= logit_softcap;
1112
+ }
1113
+
1114
+ const uint32_t n_head = Q->ne[2];
1115
// Largest power of two that is <= n_head.
+ const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
1116
+
1117
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1118
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1119
+
1120
+ // TODO other tensor dimensions after removal of WMMA kernel:
1121
+ const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]);
1122
+
1123
+ GGML_ASSERT(block_dim.x % warp_size == 0);
1124
+
1125
// Main kernel. Output goes to dst_tmp when partial results must be combined afterwards
// (non-stream-k with parallel_blocks > 1), otherwise straight into KQV->data.
+ lauch_kernel<fattn_kernel, warp_size>(
1126
+ blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data,
1127
+ mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr,
1128
+ !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1,
1129
+ n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0],
1130
+ K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0,
1131
+ mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
1132
+ mask ? mask->nb[3] : 0);
1133
// NOTE(review): SYCL_CHECK is given the constant 0 here; presumably a placeholder - confirm.
+ SYCL_CHECK(0);
1134
+
1135
+ if (stream_k) {
1136
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
1137
// Fixup grid: one work-group per (main-kernel block, ncols1 index, ncols2 index), DV work-items each.
+ const dpct::dim3 block_dim_combine(DV, 1, 1);
1138
+ const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 };
1139
+
1140
+ main_stream->submit([&](sycl::handler & cgh) {
1141
// Copy everything the device lambda needs into locals so it can capture them by value.
+ auto KQV_data_ct0 = (float *) KQV->data;
1142
+ auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
1143
+ auto Q_ne_ct2 = Q->ne[1];
1144
+ auto Q_ne_ct3 = Q->ne[2];
1145
+ auto Q_ne_ct4 = Q->ne[3];
1146
+ auto K_ne_ct5 = K->ne[1];
1147
+ auto K_ne_ct6 = K->ne[2];
1148
+
1149
+ cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
1150
+ [=](sycl::nd_item<3> item_ct1) {
1151
+ GGML_UNUSED(item_ct1);
1152
+ flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
1153
+ Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
1154
+ K_ne_ct5, K_ne_ct6, nbatch_fa);
1155
+ });
1156
+ });
1157
+ }
1158
+ } else if (parallel_blocks > 1) {
1159
// Combine grid: one work-group per output row (Q->ne[1] x Q->ne[2] x Q->ne[3]), DV work-items each.
+ const dpct::dim3 block_dim_combine(DV, 1, 1);
1160
+ const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
1161
// Scratch for one float2 meta entry per parallel block.
+ const size_t nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2);
1162
+ main_stream->submit([&](sycl::handler & cgh) {
1163
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh);
1164
+
1165
+ auto dst_tmp_ptr_ct0 = dst_tmp.ptr;
1166
+ auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
1167
+ auto KQV_data_ct2 = (float *) KQV->data;
1168
+
1169
+ cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
1170
+ [=](sycl::nd_item<3> item_ct1) {
1171
+ GGML_UNUSED(item_ct1);
1172
+ flash_attn_combine_results<DV>(
1173
+ dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,
1174
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
1175
+ });
1176
+ });
1177
+ }
1178
// NOTE(review): SYCL_CHECK is given the constant 0 here; presumably a placeholder - confirm.
+ SYCL_CHECK(0);
1179
+ }