whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -15,18 +15,9 @@
15
15
 
16
16
  #include <sycl/sycl.hpp>
17
17
  #include <sycl/half_type.hpp>
18
- #include <syclcompat/math.hpp>
19
- #include <map>
20
-
21
- #ifdef GGML_SYCL_USE_INTEL_ONEMKL
22
18
  #include <oneapi/mkl.hpp>
23
- // Allow to use the same namespace for Intel oneMKL and oneMath
24
- namespace oneapi {
25
- namespace math = mkl;
26
- }
27
- #else
28
- #include <oneapi/math.hpp>
29
- #endif
19
+
20
+ #include <map>
30
21
 
31
22
  #include "ggml.h"
32
23
 
@@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
92
83
  }
93
84
 
94
85
  template <typename Ts> struct matrix_info_t {
95
- oneapi::math::transpose transpose_info[2];
86
+ oneapi::mkl::transpose transpose_info[2];
96
87
  Ts value_info[2];
97
88
  std::int64_t size_info[3];
98
89
  std::int64_t ld_info[3];
99
90
  std::int64_t groupsize_info;
100
91
  };
101
92
 
102
- inline auto get_onemath_backend(sycl::queue& queue)
103
- #if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
104
- -> sycl::queue&
105
- #endif
106
- {
107
- // If the backend is known at compile-time, use oneMath backend_selector to use
108
- // compile-time dispatching and avoid the need to dlopen libraries. Otherwise
109
- // fallback to runtime dispatching.
110
- #if defined(GGML_SYCL_NVIDIA)
111
- return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
112
- #elif defined(GGML_SYCL_AMD)
113
- return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
114
- #elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
115
- return queue;
116
- #else
117
- static_assert(false, "Unsupported backend");
118
- #endif
119
- }
120
-
121
93
  namespace dpct
122
94
  {
123
95
  typedef sycl::queue *queue_ptr;
@@ -277,6 +249,26 @@ namespace dpct
277
249
 
278
250
  } // namespace detail
279
251
 
252
+ // COPY from DPCT head files
253
+ /// dim3 is used to store 3 component dimensions.
254
+ class dim3 {
255
+ public:
256
+ unsigned x, y, z;
257
+
258
+ constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
259
+ : x(x), y(y), z(z) {}
260
+
261
+ dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}
262
+
263
+ operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
264
+ }; // namespace dim3
265
+
266
+ inline dim3 operator*(const dim3 &a, const dim3 &b) {
267
+ return dim3{a.x * b.x, a.y * b.y, a.z * b.z};
268
+ }
269
+ // COPY from DPCT head files
270
+
271
+
280
272
  /// Pitched 2D/3D memory data.
281
273
  class pitched_data
282
274
  {
@@ -1715,7 +1707,7 @@ namespace dpct
1715
1707
  namespace detail
1716
1708
  {
1717
1709
  template <class Ta, class Tb, class Tc, class Ts>
1718
- inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
1710
+ inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
1719
1711
  int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1720
1712
  const void * beta, void * c, int ldc) {
1721
1713
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
@@ -1723,7 +1715,7 @@ namespace dpct
1723
1715
  auto data_a = get_memory<const Ta>(a);
1724
1716
  auto data_b = get_memory<const Tb>(b);
1725
1717
  auto data_c = get_memory<Tc>(c);
1726
- oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
1718
+ oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
1727
1719
  lda, data_b, ldb, beta_value, data_c, ldc);
1728
1720
  }
1729
1721
 
@@ -1755,7 +1747,7 @@ namespace dpct
1755
1747
  };
1756
1748
 
1757
1749
  template <class Ta, class Tb, class Tc, class Ts>
1758
- inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1750
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1759
1751
  int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1760
1752
  int ldb, const void * beta, void ** c, int ldc, int batch_size,
1761
1753
  matrix_info_t<float> * matrix_info) {
@@ -1774,8 +1766,8 @@ namespace dpct
1774
1766
  matrix_info->ld_info[2] = ldc;
1775
1767
  matrix_info->groupsize_info = batch_size;
1776
1768
 
1777
- sycl::event e = oneapi::math::blas::column_major::gemm_batch(
1778
- get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
1769
+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1770
+ q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
1779
1771
  matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
1780
1772
  reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1781
1773
  reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
@@ -1784,7 +1776,7 @@ namespace dpct
1784
1776
  }
1785
1777
 
1786
1778
  template <class Ta, class Tb, class Tc, class Ts>
1787
- inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1779
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1788
1780
  int m, int n, int k, const void * alpha, const void * a, int lda,
1789
1781
  long long int stride_a, const void * b, int ldb, long long int stride_b,
1790
1782
  const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
@@ -1793,7 +1785,7 @@ namespace dpct
1793
1785
  auto data_a = get_memory<const Ta>(a);
1794
1786
  auto data_b = get_memory<const Tb>(b);
1795
1787
  auto data_c = get_memory<Tc>(c);
1796
- oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
1788
+ oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
1797
1789
  data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1798
1790
  data_c, ldc, stride_c, batch_size);
1799
1791
  }
@@ -1840,10 +1832,31 @@ namespace dpct
1840
1832
  : id);
1841
1833
  }
1842
1834
 
1835
+ template <typename T1, typename T2>
1836
+ using dot_product_acc_t = std::conditional_t<
1837
+ std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
1838
+ uint32_t,
1839
+ int32_t>;
1840
+
1841
+ template <typename T>
1842
+ sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val) {
1843
+ return sycl::vec<T, 1>(val)
1844
+ .template as<sycl::vec<
1845
+ std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>,
1846
+ 4>>()
1847
+ .template convert<T>();
1848
+ }
1849
+
1843
1850
  template <typename T1, typename T2, typename T3>
1844
- inline auto dp4a(T1 a, T2 b, T3 c)
1845
- {
1846
- return syclcompat::dp4a(a, b, c);
1851
+ inline auto dp4a(T1 a, T2 b, T3 c) {
1852
+ dot_product_acc_t<T1, T2> res = c;
1853
+ auto va = extract_and_sign_or_zero_extend4(a);
1854
+ auto vb = extract_and_sign_or_zero_extend4(b);
1855
+ res += va[0] * vb[0];
1856
+ res += va[1] * vb[1];
1857
+ res += va[2] * vb[2];
1858
+ res += va[3] * vb[3];
1859
+ return res;
1847
1860
  }
1848
1861
 
1849
1862
  struct sub_sat
@@ -2259,7 +2272,7 @@ namespace dpct
2259
2272
  sycl::range<3>(x, y, 1), direction);
2260
2273
  }
2261
2274
 
2262
- inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
2275
+ inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
2263
2276
  int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
2264
2277
  library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
2265
2278
  library_data_t scaling_type) {
@@ -2326,7 +2339,7 @@ namespace dpct
2326
2339
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2327
2340
  library_data_t::real_float, library_data_t::real_float):
2328
2341
  {
2329
- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2342
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2330
2343
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2331
2344
  break;
2332
2345
  }
@@ -2365,7 +2378,7 @@ namespace dpct
2365
2378
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2366
2379
  library_data_t::real_bfloat16, library_data_t::real_float):
2367
2380
  {
2368
- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2381
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2369
2382
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2370
2383
  break;
2371
2384
  }
@@ -2407,7 +2420,7 @@ namespace dpct
2407
2420
  /// \param [in] ldc Leading dimension of C.
2408
2421
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2409
2422
  /// \param [in] scaling_type Data type of the scaling factors.
2410
- inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2423
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2411
2424
  int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2412
2425
  const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2413
2426
  library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2445,7 +2458,7 @@ namespace dpct
2445
2458
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2446
2459
  library_data_t::real_bfloat16, library_data_t::real_float):
2447
2460
  {
2448
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2461
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2449
2462
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2450
2463
  break;
2451
2464
  }
@@ -2453,7 +2466,7 @@ namespace dpct
2453
2466
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2454
2467
  library_data_t::real_float, library_data_t::real_float):
2455
2468
  {
2456
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2469
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2457
2470
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2458
2471
  break;
2459
2472
  }
@@ -2529,7 +2542,7 @@ namespace dpct
2529
2542
  /// \param [in] stride_c Stride between the different C matrices.
2530
2543
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2531
2544
  /// \param [in] scaling_type Data type of the scaling factors.
2532
- inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2545
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2533
2546
  int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
2534
2547
  long long int stride_a, const void * b, library_data_t b_type, int ldb,
2535
2548
  long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
@@ -2602,7 +2615,7 @@ namespace dpct
2602
2615
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2603
2616
  library_data_t::real_bfloat16, library_data_t::real_float):
2604
2617
  {
2605
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2618
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2606
2619
  q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2607
2620
  batch_size);
2608
2621
  break;
@@ -2611,7 +2624,7 @@ namespace dpct
2611
2624
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2612
2625
  library_data_t::real_float, library_data_t::real_float):
2613
2626
  {
2614
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2627
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2615
2628
  q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2616
2629
  batch_size);
2617
2630
  break;
@@ -2952,6 +2965,810 @@ namespace dpct
2952
2965
  atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
2953
2966
  }
2954
2967
 
2968
/// Assembles a 32-bit word from four bytes picked out of the 8-byte pool
/// formed by `b` (high half) and `a` (low half).
/// Result byte k (k = 0 is least significant) is pool byte ((s >> 4k) & 0x7);
/// bit 3 of each selector nibble and all bits of `s` above bit 15 are ignored.
inline unsigned int byte_level_permute(
    unsigned int a, unsigned int b, unsigned int s) {
    const std::uint64_t pool = (static_cast<std::uint64_t>(b) << 32) | a;
    unsigned int result = 0;
    for (unsigned int k = 0; k < 4; ++k) {
        const unsigned int pick = (s >> (4 * k)) & 0x7;
        const auto picked =
            static_cast<unsigned int>((pool >> (8 * pick)) & 0xff);
        result |= picked << (8 * k);
    }
    return result;
}

/// Byte permute with an optional selector mode.
/// mode 0 passes `sel` straight to byte_level_permute; modes 1..6 substitute a
/// fixed selector from the table below, chosen by the low two bits of `sel`;
/// every other mode yields 0.
inline uint32_t byte_level_permute_custom(
    uint32_t low32, uint32_t high32, uint32_t sel, int mode = 0) {
    constexpr uint16_t mode_selectors[6][4] = {
        {0x3210, 0x4321, 0x5432, 0x6543}, // mode 1: forward 4-byte extract
        {0x5670, 0x6701, 0x7012, 0x0123}, // mode 2: backward 4-byte extract
        {0x0000, 0x1111, 0x2222, 0x3333}, // mode 3: replicate 8-bit values
        {0x3210, 0x3211, 0x3222, 0x3333}, // mode 4: edge clamp left
        {0x0000, 0x1110, 0x2210, 0x3210}, // mode 5: edge clamp right
        {0x1010, 0x3232, 0x1010, 0x3232}  // mode 6: replicate 16-bit values
    };

    if (mode == 0) {
        return byte_level_permute(low32, high32, sel);
    }
    if (mode < 1 || mode > 6) {
        return 0;
    }
    return byte_level_permute(low32, high32, mode_selectors[mode - 1][sel & 0x3]);
}
2999
+
3000
template <int n_nondefault_params, int n_default_params, typename T>
class args_selector;

/// args_selector is a helper class for extracting arguments from an
/// array of pointers to arguments, or from a packed buffer of arguments,
/// to pass to a kernel function.
///
/// \tparam R(Ts...) The type of the kernel.
/// \tparam n_nondefault_params The number of non-default parameters of the
/// kernel (excluding parameters such as sycl::nd_item, which are not passed).
/// \tparam n_default_params The number of default parameters of the kernel.
///
/// Example usage:
/// With the following kernel:
///   void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float f=.1) {}
/// and with the declaration:
///   args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
/// we have:
///   selector.get<0>() returns a reference to sycl::float2*,
///   selector.get<1>() returns a reference to int,
///   selector.get<2>() returns a reference to float
template <int n_nondefault_params, int n_default_params, typename R,
          typename... Ts>
class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
  private:
    void **kernel_params;
    char *args_buffer;

    // Maps a caller-visible argument index to its position in Ts...: indices
    // at or past n_nondefault_params skip the parameters that are not passed
    // (between the non-default and the trailing default parameters).
    template <int i> static constexpr int account_for_default_params() {
        constexpr int n_total_params = sizeof...(Ts);
        if constexpr (i >= n_nondefault_params) {
            return n_total_params - n_default_params +
                   (i - n_nondefault_params);
        } else {
            return i;
        }
    }

  public:
    /// Get the type of the ith argument of R(Ts...)
    /// \param [in] i Index of parameter to get
    /// \returns Type of ith parameter
    template <int i>
    using arg_type = std::tuple_element_t<account_for_default_params<i>(),
                                          std::tuple<Ts...>>;
    static constexpr int params_num = sizeof...(Ts);

  private:
    // Byte offset of argument i inside the packed args buffer: each argument
    // follows the previous one, bumped up to its own alignment.
    template <int i> static constexpr int get_offset() {
        if constexpr (i == 0) {
            // we can assume args_buffer is properly aligned to the
            // first argument
            return 0;
        } else {
            constexpr int prev_off = get_offset<i - 1>();
            constexpr int prev_past_end =
                prev_off + sizeof(arg_type<i - 1>);
            using T = arg_type<i>;
            // is the past-the-end of the i-1st element properly aligned
            // with the ith element's alignment?
            if constexpr (prev_past_end % alignof(T) == 0) {
                return prev_past_end;
            }
            // otherwise bump prev_past_end to match alignment
            else {
                return prev_past_end +
                       (alignof(T) - (prev_past_end % alignof(T)));
            }
        }
    }

    // Scans `extra` for the marker value 1 and returns the entry after it as
    // the args buffer; a 0 entry ends the scan. Returns nullptr if `extra` is
    // null or no marker is found.
    // NOTE(review): the 0/1 markers presumably mirror CUDA's
    // CU_LAUNCH_PARAM_END / CU_LAUNCH_PARAM_BUFFER_POINTER — confirm.
    static char *get_args_buffer(void **extra) {
        if (!extra)
            return nullptr;
        for (; reinterpret_cast<std::size_t>(*extra) != 0; ++extra) {
            if (reinterpret_cast<std::size_t>(*extra) == 1) {
                return static_cast<char *>(*(extra + 1));
            }
        }
        return nullptr;
    }

  public:
    /// If kernel_params is nonnull, then args_selector will
    /// extract arguments from kernel_params. Otherwise, it
    /// will extract them from extra.
    /// \param [in] kernel_params Array of pointers to arguments
    /// or a null pointer.
    /// \param [in] extra Array containing pointer to argument buffer.
    args_selector(void **kernel_params, void **extra)
        : kernel_params(kernel_params),
          args_buffer(get_args_buffer(extra)) {}

    /// Get a reference to the ith argument extracted from kernel_params
    /// or extra.
    /// \param [in] i Index of argument to get
    /// \returns Reference to the ith argument
    template <int i> arg_type<i> &get() {
        if (kernel_params) {
            return *static_cast<arg_type<i> *>(kernel_params[i]);
        } else {
            return *reinterpret_cast<arg_type<i> *>(args_buffer +
                                                    get_offset<i>());
        }
    }
}; // COPY from DPCT head file
// /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
3108
+
3109
+ /// Utility class for launching SYCL kernels through kernel
3110
+ /// function wrapper.
3111
+ /// For example:
3112
+ /// A SYCL kernel function:
3113
+ /// void kernel_func(int *ptr, sycl::nd_item<3> item);
3114
+ /// Kernel function wrapper:
3115
+ /// void kernel_func_wrapper(int *ptr) {
3116
+ /// sycl::queue queue = *dpct::kernel_launcher::_que;
3117
+ /// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
3118
+ /// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
3119
+ /// queue.parallel_for(
3120
+ /// nr,
3121
+ /// [=](sycl::nd_item<3> item_ct1) {
3122
+ /// kernel_func(ptr, item_ct1);
3123
+ /// });
3124
+ /// }
3125
+ /// Then launch the kernel through wrapper like:
3126
+ /// typedef void(*fpt)(int *);
3127
+ /// fpt fp = kernel_func_wrapper;
3128
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
3129
+ /// device_ptr);
3130
+ /// If the origin function type is erased, then need to register it first:
3131
+ /// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
3132
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
3133
+ /// 0, 0);
3134
+ class kernel_launcher {
3135
+ template <typename FuncT, typename ArgSelector, std::size_t... Index>
3136
+ static void launch_helper(FuncT &&func, ArgSelector &selector,
3137
+ std::index_sequence<Index...>) {
3138
+ func(selector.template get<Index>()...);
3139
+ }
3140
+ static void set_execution_config(dim3 group_range, dim3 local_range,
3141
+ unsigned int local_mem_size,
3142
+ queue_ptr que) {
3143
+ if (que) {
3144
+ _que = que;
3145
+ } else {
3146
+ _que = &get_default_queue();
3147
+ }
3148
+ _nr = sycl::nd_range<3>(
3149
+ static_cast<sycl::range<3>>(group_range * local_range),
3150
+ static_cast<sycl::range<3>>(local_range));
3151
+ _local_mem_size = local_mem_size;
3152
+
3153
+
3154
+ };
3155
+ static inline std::mutex kernel_function_ptr_map_mutex;
3156
+
3157
+ public:
3158
+ /// Variables for storing execution configuration.
3159
+ static inline thread_local sycl::queue *_que = nullptr;
3160
+ static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
3161
+ static inline thread_local unsigned int _local_mem_size = 0;
3162
+ /// Map for retrieving launchable functor from a raw pointer.
3163
+ static inline std::map<
3164
+ const void *,
3165
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
3166
+ kernel_function_ptr_map = {};
3167
+
3168
+ /// Registers a kernel function pointer with a corresponding launchable
3169
+ /// functor.
3170
+ /// \param [in] func Pointer to the kernel function.
3171
+ /// \param [in] launcher Functor to handle kernel invocation.
3172
+ static void register_kernel_ptr(
3173
+ const void *func,
3174
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
3175
+ launcher) {
3176
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
3177
+ kernel_function_ptr_map[func] = std::move(launcher);
3178
+ }
3179
+ /// Launches a kernel function with arguments provided directly through
3180
+ /// kernel function wrapper.
3181
+ /// \tparam FuncT Type of the kernel function wrapper.
3182
+ /// \tparam ArgsT Types of kernel arguments.
3183
+ /// \param [in] func Pointer to the kernel function wrapper.
3184
+ /// \param [in] group_range SYCL group range.
3185
+ /// \param [in] local_range SYCL local range.
3186
+ /// \param [in] local_mem_size The size of local memory required by the
3187
+ /// kernel function. \param [in] que SYCL queue used to execute kernel.
3188
+ /// \param [in] args Kernel arguments.
3189
+ template <typename FuncT, typename... ArgsT>
3190
+ static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
3191
+ launch(FuncT *func, dim3 group_range, dim3 local_range,
3192
+ unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
3193
+ set_execution_config(group_range, local_range, local_mem_size, que);
3194
+ func(args...);
3195
+ }
3196
+ /// Launches a kernel function through registered kernel function
3197
+ /// wrapper. \param [in] func Pointer to the registered kernel function
3198
+ /// wrapper. \param [in] group_range SYCL group range. \param [in]
3199
+ /// local_range SYCL local range. \param [in] args Array of pointers to
3200
+ /// kernel arguments. \param [in] local_mem_size The size of local
3201
+ /// memory required by the kernel function. \param [in] que SYCL queue
3202
+ /// used to execute kernel.
3203
+ static void launch(const void *func, dim3 group_range, dim3 local_range,
3204
+ void **args, unsigned int local_mem_size,
3205
+ queue_ptr que) {
3206
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
3207
+ auto Iter = kernel_function_ptr_map.find(func);
3208
+ if (Iter == kernel_function_ptr_map.end()) {
3209
+ throw std::runtime_error("dpct::launch() : no registered "
3210
+ "kernel function wrapper found.");
3211
+ }
3212
+ (Iter->second)(group_range, local_range, args, local_mem_size, que);
3213
+ }
3214
+ /// Launches a kernel function with packed arguments through kernel
3215
+ /// function wrapper.
3216
+ /// \tparam FuncT Type of the kernel function wrapper.
3217
+ /// \param [in] func Pointer to the kernel function wrapper.
3218
+ /// \param [in] group_range SYCL group range.
3219
+ /// \param [in] local_range SYCL local range.
3220
+ /// \param [in] args Array of pointers to kernel arguments.
3221
+ /// \param [in] local_mem_size The size of local memory required by the
3222
+ /// kernel function. \param [in] que SYCL queue used to execute kernel.
3223
+ template <typename FuncT>
3224
+ static std::enable_if_t<std::is_function_v<FuncT>, void>
3225
+ launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
3226
+ unsigned int local_mem_size, queue_ptr que) {
3227
+ constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
3228
+ set_execution_config(group_range, local_range, local_mem_size, que);
3229
+ args_selector<p_num, p_num, FuncT> selector(args, nullptr);
3230
+ launch_helper(func, selector, std::make_index_sequence<p_num>{});
3231
+ }
3232
+ }; // COPY from DPCT head file
3233
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
3234
+
3235
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
3236
+ template <typename T>
3237
+ T select_from_sub_group(
3238
+ sycl::sub_group g,
3239
+ T x,
3240
+ int remote_local_id,
3241
+ int logical_sub_group_size = 32) {
3242
+ unsigned int start_index = g.get_local_linear_id() /
3243
+ logical_sub_group_size *
3244
+ logical_sub_group_size;
3245
+ return sycl::select_from_group(
3246
+ g, x, start_index + remote_local_id % logical_sub_group_size);
3247
+ }
3248
+
3249
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
3250
+ template <typename T>
3251
+ void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
3252
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
3253
+ int lane = sg.get_local_linear_id();
3254
+
3255
+ int lane_group8_row = lane / 8;
3256
+ int lane_group8_col = lane % 8;
3257
+
3258
+ if (!trans) {
3259
+ // calculate the source lane
3260
+ int src_lane = 2 * lane_group8_row;
3261
+ if (lane_group8_col >= 4)
3262
+ src_lane += 1;
3263
+
3264
+ // Broadcast the address from the source lane
3265
+ auto recv_addr_uintp =
3266
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
3267
+
3268
+ // Cast the received address from uintptr_t to the type of 'm'
3269
+ auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
3270
+
3271
+ // Non-transposed load
3272
+ *m = recv_addr[lane_group8_col % 4];
3273
+ } else {
3274
+ // calculate the source lane
3275
+ int src_lane = (lane % 4) * 2;
3276
+
3277
+ // Broadcast the address from the source lane
3278
+ auto recv_addr_uintp_1 =
3279
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
3280
+ auto recv_addr_uintp_2 =
3281
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
3282
+
3283
+ // Cast the received address from uintptr_t to 'half *'
3284
+ auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
3285
+ auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
3286
+
3287
+ // Transposed load
3288
+ int index = lane / 4;
3289
+ sycl::half val0 = recv_addr_1[index];
3290
+ sycl::half val1 = recv_addr_2[index];
3291
+
3292
+ // Combine the two 16-bits into one 32-bit value
3293
+ sycl::half2 val = sycl::half2(val0, val1);
3294
+ *m = *reinterpret_cast<T*>(&val);
3295
+ }
3296
+ }
3297
+
3298
+ template <typename T>
3299
+ void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
3300
+ // Load 1st matrix
3301
+ ldmatrix(addr, m1, trans, 0);
3302
+ // Load 2nd matrix
3303
+ ldmatrix(addr, m2, trans, 1);
3304
+ }
3305
+
3306
+ template <typename T>
3307
+ void ldmatrix(
3308
+ uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
3309
+ // Load 1st matrix
3310
+ ldmatrix(addr, m1, trans, 0);
3311
+ // Load 2nd matrix
3312
+ ldmatrix(addr, m2, trans, 1);
3313
+ // Load 3rd matrix
3314
+ ldmatrix(addr, m3, trans, 2);
3315
+ // Load 4th matrix
3316
+ ldmatrix(addr, m4, trans, 3);
3317
+ }
3318
+
3319
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
3320
+
3321
/// A helper struct that defines the pack type for the input matrix fragments
/// of the mma() function, based on the type of the input matrix fragments.
/// The primary template packs fragments of every element type into a 32-bit
/// uint32_t word.
/// NOTE(review): an earlier version of this comment said f16, bf16 and s8
/// specializations are defined below, but none appear between this template
/// and mma() — confirm whether they are still intended.
/// \tparam [in] T The type of the input matrix fragments
template <typename T>
struct MMAType {
    using PackType = uint32_t;
};
3331
+
3332
+ /// Each work item of a sub-group (limited to size 32) calling this function
3333
+ /// calculates a subset fragment for the output matrix D using MAD operation
3334
+ /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
3335
+ /// types:
3336
+ /// - m8n8k4 (f32.f16.f16.f32)
3337
+ /// - m8n8k16 (s32.s8.s8.s32)
3338
+ /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
3339
+ /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
3340
+ /// - m16n8k32 (s32.s8.s8.s32)
3341
+ /// Here, m, n & k define the shapes of A, B & C matrices respectively
3342
+ /// (A = [m x k], B = [k x n], C = [m x n]).
3343
+ /// \tparam [in] M The rows of A, C & D matrices
3344
+ /// \tparam [in] N The columns of B, C, D matrices
3345
+ /// \tparam [in] K The columns & rows of A & B matrices respectively
3346
+ /// \tparam [in] ABType The type of the input matrix (A & B) fragment
3347
+ /// \tparam [in] CDType The type of the output matrix (C & D) fragment
3348
+ /// \param [out] d_mat_frag The fragment of the output matrix D to store the
3349
+ /// result of A * B + C
3350
+ /// \param [in] a_mat_frag The fragment of the input matrix A to be
3351
+ /// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
3352
+ /// the input matrix B to be multiplied with A matrix fragment \param [in]
3353
+ /// c_mat_frag The fragment of the input matrix C to be added with the
3354
+ /// result of A * B fragments
3355
+ template <int M, int N, int K, typename ABType, typename CDType>
3356
+ void mma(
3357
+ volatile void** d_mat_frag,
3358
+ void* a_mat_frag,
3359
+ void* b_mat_frag,
3360
+ void* c_mat_frag) {
3361
+ auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
3362
+ auto a =
3363
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
3364
+ auto b =
3365
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
3366
+ auto c = reinterpret_cast<CDType*>(c_mat_frag);
3367
+
3368
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
3369
+ int lane = sg.get_local_linear_id();
3370
+
3371
+ static_assert(
3372
+ (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
3373
+ (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
3374
+ (M == 16 && N == 8 && K == 32),
3375
+ "Unsupported MMA shape!");
3376
+
3377
+ short row_load_offset = 4 * (lane >> 2);
3378
+ short col_load_offset = 8 * (lane % 4);
3379
+
3380
+ if constexpr (M == 8 && N == 8 && K == 4) {
3381
+ if constexpr (std::is_floating_point_v<CDType>) {
3382
+ col_load_offset = row_load_offset % 16;
3383
+
3384
+ // Init D matrix with fragments of C matrix
3385
+ *d[0] = c[0];
3386
+ *d[1] = c[1];
3387
+ *d[2] = c[2];
3388
+ *d[3] = c[3];
3389
+ *d[4] = c[4];
3390
+ *d[5] = c[5];
3391
+ *d[6] = c[6];
3392
+ *d[7] = c[7];
3393
+
3394
+ // Calculate the row and col offset indices to iterate through the row
3395
+ // & col fragments of A & B matrices
3396
+ int r_ind = (lane % 2) ? 1 : 0;
3397
+ int c_ind = ((lane % 4) / 2) ? 2 : 0;
3398
+
3399
+ // Each sub-group is responsible for computing a fragment size of 8*8
3400
+ // elements of matrix D for each of 4 MMA computations.
3401
+ // Each work item computes 8 elements of matrix D by gathering
3402
+ // their corresponding col & row matrix fragments of length k (4)
3403
+ // from A & B matrices respectively using below mapping logic:
3404
+ // row0 = (i % 4) if (lane < 16) else (i % 4) + 4
3405
+ // col0 = (lane % 4)
3406
+ // As each row & col fragment of A & B matrices is distributed across
3407
+ // 4 work items, each iteration of below loop loads a partial fragment
3408
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3409
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3410
+
3411
+ for (int i = 0; i < 4; i++) {
3412
+ // Load partial fragment from col0 of matrix A ({a0, a1})
3413
+ recv_a[0] =
3414
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3415
+ // Load partial fragment from col0 of matrix A ({a2, a3})
3416
+ recv_a[1] =
3417
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3418
+
3419
+ // Load partial fragment from row0 of matrix B ({b0, b1})
3420
+ recv_b[0] =
3421
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3422
+ // Load partial fragment from row0 of matrix B ({b2, b3})
3423
+ recv_b[1] =
3424
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3425
+
3426
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3427
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3428
+
3429
+ // Each work item calculates a partial product of A & B matrix
3430
+ // fragments and adds it to the corresponding D matrix fragment (for
3431
+ // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
3432
+ // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
3433
+ // * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
3434
+ // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
3435
+ // d3 += col1{ a3 } * row0{ b3 }
3436
+ *d[0] +=
3437
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
3438
+ *d[1] += static_cast<float>(ra[r_ind]) *
3439
+ static_cast<float>(rb[c_ind + 1]);
3440
+ *d[2] += static_cast<float>(ra[r_ind + 2]) *
3441
+ static_cast<float>(rb[c_ind]);
3442
+ *d[3] += static_cast<float>(ra[r_ind + 2]) *
3443
+ static_cast<float>(rb[c_ind + 1]);
3444
+
3445
+ // Load partial fragment from row1 of matrix B ({b0, b1})
3446
+ recv_b[0] =
3447
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
3448
+ // Load partial fragment from row1 of matrix B ({b2, b3})
3449
+ recv_b[1] =
3450
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
3451
+
3452
+ // (for even work item indices)
3453
+ // d0 += col0{ a0 } * row1{ b0 }
3454
+ // d1 += col0{ a0 } * row1{ b1 }
3455
+ // d2 += col1{ a2 } * row1{ b0 }
3456
+ // d3 += col1{ a2 } * row1{ b1 }
3457
+ // (for odd work item indices)
3458
+ // d0 += col0{ a1 } * row1{ b2 }
3459
+ // d1 += col0{ a1 } * row1{ b3 }
3460
+ // d2 += col1{ a3 } * row1{ b2 }
3461
+ // d3 += col1{ a3 } * row1{ b3 }
3462
+ *d[4] +=
3463
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
3464
+ *d[5] += static_cast<float>(ra[r_ind]) *
3465
+ static_cast<float>(rb[c_ind + 1]);
3466
+ *d[6] += static_cast<float>(ra[r_ind + 2]) *
3467
+ static_cast<float>(rb[c_ind]);
3468
+ *d[7] += static_cast<float>(ra[r_ind + 2]) *
3469
+ static_cast<float>(rb[c_ind + 1]);
3470
+ }
3471
+ }
3472
+ } else if constexpr (M == 8 && N == 8 && K == 16) {
3473
+ if constexpr (std::is_integral_v<ABType>) {
3474
+ // Init D matrix with fragments of C matrix
3475
+ *d[0] = c[0];
3476
+ *d[1] = c[1];
3477
+
3478
+ // Each sub-group is responsible for computing a fragment size of 16*8
3479
+ // elements of matrix D.
3480
+ // Each work item computes 2 elements of matrix D by gathering
3481
+ // their corresponding row & col matrix fragments of length k (16)
3482
+ // from A & B matrices respectively using below mapping logic:
3483
+ // row0 = ((lane % 4) * 4) + i
3484
+ // col0 = (lane >> 2)
3485
+ // As each row & col fragment of A & B matrices is distributed across
3486
+ // 4 work items, each iteration of below loop loads a partial fragment
3487
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3488
+ for (int i = 0; i < 4; i++) {
3489
+ typename MMAType<ABType>::PackType recv_a, recv_b[2];
3490
+
3491
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3492
+ recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3493
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3494
+ recv_b[0] =
3495
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3496
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
3497
+ recv_b[1] =
3498
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3499
+
3500
+ auto a = reinterpret_cast<ABType*>(&recv_a);
3501
+ auto b = reinterpret_cast<ABType*>(recv_b);
3502
+
3503
+ // Each work item calculates a partial product of A & B matrix
3504
+ // fragments and adds it to the corresponding D matrix fragment d0
3505
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3506
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,
3507
+ // a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *
3508
+ // col1{ b0, b1, b2, b3 }
3509
+ for (int j = 0; j < 4; j++) {
3510
+ *d[0] += a[j] * b[j];
3511
+ *d[1] += a[j] * b[j + 4];
3512
+ }
3513
+ }
3514
+ }
3515
+ } else if constexpr (M == 16 && N == 8 && K == 8) {
3516
+ if constexpr (std::is_floating_point_v<CDType>) {
3517
+ // Init D matrix fragment with C matrix fragment
3518
+ *d[0] = c[0];
3519
+ *d[1] = c[1];
3520
+ *d[2] = c[2];
3521
+ *d[3] = c[3];
3522
+
3523
+ // Each sub-group is responsible for computing a fragment size of 16*8
3524
+ // elements of matrix D.
3525
+ // Each work item computes 4 elements of matrix D by gathering
3526
+ // their corresponding row & col matrix fragments of length k (8)
3527
+ // from A & B matrices respectively using below mapping logic:
3528
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3529
+ // col0 = (lane % 4) * 2 + (i & 0x1)
3530
+ // As each row & col fragment of A & B matrices is distributed across
3531
+ // 4 work items, each iteration of below loop loads a partial fragment
3532
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3533
+ for (int i = 0; i < 4; i++) {
3534
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3535
+
3536
+ // Load partial fragment from row0 of matrix A ({a0, a1})
3537
+ recv_a[0] =
3538
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3539
+ // Load partial fragment from row1 of matrix A ({a2, a3})
3540
+ recv_a[1] =
3541
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3542
+ // Load partial fragment from col0 of matrix B ({b0, b1})
3543
+ recv_b[0] =
3544
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3545
+ // Load partial fragment from col1 of matrix B ({b0, b1})
3546
+ recv_b[1] =
3547
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3548
+
3549
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3550
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3551
+
3552
+ // Each work item calculates a partial product of A & B matrix
3553
+ // fragments and adds it to the corresponding D matrix fragment d0
3554
+ // += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{
3555
+ // b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3
3556
+ // } * col1{ b0, b1 }
3557
+ for (int j = 0; j < 2; j++) {
3558
+ *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
3559
+ *d[1] +=
3560
+ static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
3561
+ *d[2] +=
3562
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
3563
+ *d[3] +=
3564
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
3565
+ }
3566
+ }
3567
+ }
3568
+ } else if constexpr (M == 16 && N == 8 && K == 16) {
3569
+ if constexpr (std::is_floating_point_v<CDType>) {
3570
+ // Init D matrix fragment with C matrix fragment
3571
+ *d[0] = c[0];
3572
+ *d[1] = c[1];
3573
+ *d[2] = c[2];
3574
+ *d[3] = c[3];
3575
+
3576
+ // Each sub-group is responsible for computing a fragment size of 16*8
3577
+ // elements of matrix D.
3578
+ // Each work item computes 4 elements of matrix D by gathering
3579
+ // their corresponding row & col matrix fragments of length k (8)
3580
+ // from A & B matrices respectively using below mapping logic:
3581
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3582
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
3583
+ // As each row & col fragment of A & B matrices is distributed across
3584
+ // 4 work items, each iteration of below loop loads a partial fragment
3585
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3586
+ for (int i = 0; i < 4; i++) {
3587
+ typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
3588
+
3589
+ // Load partial fragment from row0 of matrix A ({a0, a1})
3590
+ recv_a[0] =
3591
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3592
+ // Load partial fragment from row0 of matrix A ({a2, a3})
3593
+ recv_a[1] =
3594
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
3595
+ // Load partial fragment from row1 of matrix A ({a0, a1})
3596
+ recv_a[2] =
3597
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3598
+ // Load partial fragment from row1 of matrix A ({a2, a3})
3599
+ recv_a[3] =
3600
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
3601
+
3602
+ // Load partial fragment from col0 of matrix B ({b0, b1})
3603
+ recv_b[0] =
3604
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3605
+ // Load partial fragment from col0 of matrix B ({b2, b3})
3606
+ recv_b[1] =
3607
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3608
+ // Load partial fragment from col1 of matrix B ({b0, b1})
3609
+ recv_b[2] =
3610
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
3611
+ // Load partial fragment from col1 of matrix B ({b2, b3})
3612
+ recv_b[3] =
3613
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
3614
+
3615
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3616
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3617
+
3618
+ // Each work item calculates a partial product of A & B matrix
3619
+ // fragments and adds it to the corresponding D matrix fragment d0
3620
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3621
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,
3622
+ // a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
3623
+ // col1{ b0, b1, b2, b3 }
3624
+ for (int j = 0; j < 4; j++) {
3625
+ *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
3626
+ *d[1] +=
3627
+ static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
3628
+ *d[2] +=
3629
+ static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
3630
+ *d[3] += static_cast<CDType>(ra[j + 4]) *
3631
+ static_cast<CDType>(rb[j + 4]);
3632
+ }
3633
+ }
3634
+ } else if constexpr (std::is_integral_v<ABType>) {
3635
+ // Init D matrix with fragments of C matrix
3636
+ *d[0] = c[0];
3637
+ *d[1] = c[1];
3638
+ *d[2] = c[2];
3639
+ *d[3] = c[3];
3640
+
3641
+ // Each sub-group is responsible for computing a fragment size of 16*8
3642
+ // elements of matrix D.
3643
+ // Each work item computes 4 elements of matrix D by gathering
3644
+ // their corresponding row & col matrix fragments of length k (8)
3645
+ // from A & B matrices respectively using below mapping logic:
3646
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3647
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
3648
+ // As each row & col fragment of A & B matrices is distributed across
3649
+ // 4 work items, each iteration of below loop loads a partial fragment
3650
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3651
+ for (int i = 0; i < 4; i++) {
3652
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3653
+
3654
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3655
+ recv_a[0] =
3656
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3657
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
3658
+ recv_a[1] =
3659
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3660
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3661
+ recv_b[0] =
3662
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3663
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
3664
+ recv_b[1] =
3665
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3666
+
3667
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3668
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3669
+
3670
+ // Each work item calculates a partial product of A & B matrix
3671
+ // fragments and adds it to the corresponding D matrix fragment d0
3672
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3673
+ // a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,
3674
+ // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
3675
+ // col1{ b4, b5, b6, b7 }
3676
+ for (int i = 0; i < 4; i++) {
3677
+ *d[0] += ra[i] * rb[i];
3678
+ *d[1] += ra[i] * rb[i + 4];
3679
+ *d[2] += ra[i + 4] * rb[i];
3680
+ *d[3] += ra[i + 4] * rb[i + 4];
3681
+ }
3682
+ }
3683
+ }
3684
+ } else if constexpr (M == 16 && N == 8 && K == 32) {
3685
+ if constexpr (std::is_integral_v<ABType>) {
3686
+ // Init D matrix with fragments of C matrix
3687
+ *d[0] = c[0];
3688
+ *d[1] = c[1];
3689
+ *d[2] = c[2];
3690
+ *d[3] = c[3];
3691
+
3692
+ // Each sub-group is responsible for computing a fragment size of 16*8
3693
+ // elements of matrix D.
3694
+ // Each work item computes 4 elements of matrix D by gathering
3695
+ // their corresponding row & col matrix fragments of length k (32)
3696
+ // from A & B matrices respectively using below mapping logic:
3697
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3698
+ // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i
3699
+ // & 0x3) As each row & col fragment of A & B matrices is distributed
3700
+ // across 4 work items, each iteration of below loop loads a partial
3701
+ // fragment of matrix A (row) and matrix B (col) using the row & col
3702
+ // offsets.
3703
+ for (int i = 0; i < 4; i++) {
3704
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3705
+
3706
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3707
+ recv_a[0] =
3708
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3709
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
3710
+ recv_a[1] =
3711
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3712
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3713
+ recv_b[0] =
3714
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3715
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
3716
+ recv_b[1] =
3717
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3718
+
3719
+ auto a = reinterpret_cast<ABType*>(recv_a);
3720
+ auto b = reinterpret_cast<ABType*>(recv_b);
3721
+
3722
+ // Each work item calculates a partial product of A & B matrix
3723
+ // fragments and adds it to the corresponding D matrix fragment d0
3724
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3725
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,
3726
+ // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
3727
+ // col1{ b0, b1, b2, b3 }
3728
+ for (int j = 0; j < 4; j++) {
3729
+ *d[0] += a[j] * b[j];
3730
+ *d[1] += a[j] * b[j + 4];
3731
+ *d[2] += a[j + 4] * b[j];
3732
+ *d[3] += a[j + 4] * b[j + 4];
3733
+ }
3734
+ }
3735
+
3736
+ for (int i = 0; i < 4; i++) {
3737
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3738
+
3739
+ // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
3740
+ recv_a[0] =
3741
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
3742
+ // Load partial fragment from row1 of matrix A ({a12, a13, a14,
3743
+ // a15})
3744
+ recv_a[1] =
3745
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
3746
+ // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
3747
+ recv_b[0] =
3748
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3749
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
3750
+ recv_b[1] =
3751
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
3752
+
3753
+ auto a = reinterpret_cast<ABType*>(recv_a);
3754
+ auto b = reinterpret_cast<ABType*>(recv_b);
3755
+
3756
+ // Each work item calculates a partial product of A & B matrix
3757
+ // fragments and adds it to the corresponding D matrix fragment d0
3758
+ // += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{
3759
+ // a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,
3760
+ // a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,
3761
+ // a15 } * col1{ b4, b5, b6, b7 }
3762
+ for (int j = 0; j < 4; j++) {
3763
+ *d[0] += a[j] * b[j];
3764
+ *d[1] += a[j] * b[j + 4];
3765
+ *d[2] += a[j + 4] * b[j];
3766
+ *d[3] += a[j + 4] * b[j + 4];
3767
+ }
3768
+ }
3769
+ }
3770
+ }
3771
+ }
2955
3772
  } // COPY from DPCT head files
2956
3773
 
2957
3774
  #endif // GGML_SYCL_DPCT_HELPER_HPP