whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -15,18 +15,9 @@
15
15
 
16
16
  #include <sycl/sycl.hpp>
17
17
  #include <sycl/half_type.hpp>
18
- #include <syclcompat/math.hpp>
19
- #include <map>
20
-
21
- #ifdef GGML_SYCL_USE_INTEL_ONEMKL
22
18
  #include <oneapi/mkl.hpp>
23
- // Allow to use the same namespace for Intel oneMKL and oneMath
24
- namespace oneapi {
25
- namespace math = mkl;
26
- }
27
- #else
28
- #include <oneapi/math.hpp>
29
- #endif
19
+
20
+ #include <map>
30
21
 
31
22
  #include "ggml.h"
32
23
 
@@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
92
83
  }
93
84
 
94
85
  template <typename Ts> struct matrix_info_t {
95
- oneapi::math::transpose transpose_info[2];
86
+ oneapi::mkl::transpose transpose_info[2];
96
87
  Ts value_info[2];
97
88
  std::int64_t size_info[3];
98
89
  std::int64_t ld_info[3];
99
90
  std::int64_t groupsize_info;
100
91
  };
101
92
 
102
- inline auto get_onemath_backend(sycl::queue& queue)
103
- #if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
104
- -> sycl::queue&
105
- #endif
106
- {
107
- // If the backend is known at compile-time, use oneMath backend_selector to use
108
- // compile-time dispatching and avoid the need to dlopen libraries. Otherwise
109
- // fallback to runtime dispatching.
110
- #if defined(GGML_SYCL_NVIDIA)
111
- return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
112
- #elif defined(GGML_SYCL_AMD)
113
- return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
114
- #elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
115
- return queue;
116
- #else
117
- static_assert(false, "Unsupported backend");
118
- #endif
119
- }
120
-
121
93
  namespace dpct
122
94
  {
123
95
  typedef sycl::queue *queue_ptr;
@@ -1735,7 +1707,7 @@ namespace dpct
1735
1707
  namespace detail
1736
1708
  {
1737
1709
  template <class Ta, class Tb, class Tc, class Ts>
1738
- inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
1710
+ inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
1739
1711
  int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1740
1712
  const void * beta, void * c, int ldc) {
1741
1713
  Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
@@ -1743,7 +1715,7 @@ namespace dpct
1743
1715
  auto data_a = get_memory<const Ta>(a);
1744
1716
  auto data_b = get_memory<const Tb>(b);
1745
1717
  auto data_c = get_memory<Tc>(c);
1746
- oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
1718
+ oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
1747
1719
  lda, data_b, ldb, beta_value, data_c, ldc);
1748
1720
  }
1749
1721
 
@@ -1775,7 +1747,7 @@ namespace dpct
1775
1747
  };
1776
1748
 
1777
1749
  template <class Ta, class Tb, class Tc, class Ts>
1778
- inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1750
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1779
1751
  int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1780
1752
  int ldb, const void * beta, void ** c, int ldc, int batch_size,
1781
1753
  matrix_info_t<float> * matrix_info) {
@@ -1794,8 +1766,8 @@ namespace dpct
1794
1766
  matrix_info->ld_info[2] = ldc;
1795
1767
  matrix_info->groupsize_info = batch_size;
1796
1768
 
1797
- sycl::event e = oneapi::math::blas::column_major::gemm_batch(
1798
- get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
1769
+ sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
1770
+ q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
1799
1771
  matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
1800
1772
  reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1801
1773
  reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
@@ -1804,7 +1776,7 @@ namespace dpct
1804
1776
  }
1805
1777
 
1806
1778
  template <class Ta, class Tb, class Tc, class Ts>
1807
- inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1779
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
1808
1780
  int m, int n, int k, const void * alpha, const void * a, int lda,
1809
1781
  long long int stride_a, const void * b, int ldb, long long int stride_b,
1810
1782
  const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
@@ -1813,7 +1785,7 @@ namespace dpct
1813
1785
  auto data_a = get_memory<const Ta>(a);
1814
1786
  auto data_b = get_memory<const Tb>(b);
1815
1787
  auto data_c = get_memory<Tc>(c);
1816
- oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
1788
+ oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
1817
1789
  data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1818
1790
  data_c, ldc, stride_c, batch_size);
1819
1791
  }
@@ -2300,7 +2272,7 @@ namespace dpct
2300
2272
  sycl::range<3>(x, y, 1), direction);
2301
2273
  }
2302
2274
 
2303
- inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
2275
+ inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
2304
2276
  int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
2305
2277
  library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
2306
2278
  library_data_t scaling_type) {
@@ -2367,7 +2339,7 @@ namespace dpct
2367
2339
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2368
2340
  library_data_t::real_float, library_data_t::real_float):
2369
2341
  {
2370
- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2342
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2371
2343
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2372
2344
  break;
2373
2345
  }
@@ -2406,7 +2378,7 @@ namespace dpct
2406
2378
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2407
2379
  library_data_t::real_bfloat16, library_data_t::real_float):
2408
2380
  {
2409
- detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2381
+ detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2410
2382
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2411
2383
  break;
2412
2384
  }
@@ -2448,7 +2420,7 @@ namespace dpct
2448
2420
  /// \param [in] ldc Leading dimension of C.
2449
2421
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2450
2422
  /// \param [in] scaling_type Data type of the scaling factors.
2451
- inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2423
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2452
2424
  int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2453
2425
  const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2454
2426
  library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2486,7 +2458,7 @@ namespace dpct
2486
2458
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2487
2459
  library_data_t::real_bfloat16, library_data_t::real_float):
2488
2460
  {
2489
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2461
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2490
2462
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2491
2463
  break;
2492
2464
  }
@@ -2494,7 +2466,7 @@ namespace dpct
2494
2466
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2495
2467
  library_data_t::real_float, library_data_t::real_float):
2496
2468
  {
2497
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2469
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2498
2470
  q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2499
2471
  break;
2500
2472
  }
@@ -2570,7 +2542,7 @@ namespace dpct
2570
2542
  /// \param [in] stride_c Stride between the different C matrices.
2571
2543
  /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2572
2544
  /// \param [in] scaling_type Data type of the scaling factors.
2573
- inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2545
+ inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
2574
2546
  int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
2575
2547
  long long int stride_a, const void * b, library_data_t b_type, int ldb,
2576
2548
  long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
@@ -2643,7 +2615,7 @@ namespace dpct
2643
2615
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2644
2616
  library_data_t::real_bfloat16, library_data_t::real_float):
2645
2617
  {
2646
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2618
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
2647
2619
  q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2648
2620
  batch_size);
2649
2621
  break;
@@ -2652,7 +2624,7 @@ namespace dpct
2652
2624
  library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2653
2625
  library_data_t::real_float, library_data_t::real_float):
2654
2626
  {
2655
- detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2627
+ detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
2656
2628
  q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2657
2629
  batch_size);
2658
2630
  break;
@@ -3025,6 +2997,778 @@ namespace dpct
3025
2997
  return 0;
3026
2998
  }
3027
2999
 
3000
+ template <int n_nondefault_params, int n_default_params, typename T>
3001
+ class args_selector;
3002
+
3003
+ /// args_selector is a helper class for extracting arguments from an
3004
+ /// array of pointers to arguments or buffer of arguments to pass to a
3005
+ /// kernel function.
3006
+ ///
3007
+ /// \param R(Ts...) The type of the kernel
3008
+ /// \param n_nondefault_params The number of nondefault parameters of the
3009
+ /// kernel (excluding parameters that like sycl::nd_item, etc.) \param
3010
+ /// n_default_params The number of default parameters of the kernel
3011
+ ///
3012
+ /// Example usage:
3013
+ /// With the following kernel:
3014
+ /// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float
3015
+ /// f=.1) {}
3016
+ /// and with the declaration:
3017
+ /// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
3018
+ /// we have:
3019
+ /// selector.get<0>() returns a reference to sycl::float*,
3020
+ /// selector.get<1>() returns a reference to int,
3021
+ /// selector.get<2>() returns a reference to float
3022
+ template <int n_nondefault_params, int n_default_params, typename R,
3023
+ typename... Ts>
3024
+ class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
3025
+ private:
3026
+ void **kernel_params;
3027
+ char *args_buffer;
3028
+
3029
+ template <int i> static constexpr int account_for_default_params() {
3030
+ constexpr int n_total_params = sizeof...(Ts);
3031
+ if constexpr (i >= n_nondefault_params) {
3032
+ return n_total_params - n_default_params +
3033
+ (i - n_nondefault_params);
3034
+ } else {
3035
+ return i;
3036
+ }
3037
+ }
3038
+
3039
+ public:
3040
+ /// Get the type of the ith argument of R(Ts...)
3041
+ /// \param [in] i Index of parameter to get
3042
+ /// \returns Type of ith parameter
3043
+ template <int i>
3044
+ using arg_type = std::tuple_element_t<account_for_default_params<i>(),
3045
+ std::tuple<Ts...>>;
3046
+ static constexpr int params_num = sizeof...(Ts);
3047
+
3048
+ private:
3049
+ template <int i> static constexpr int get_offset() {
3050
+ if constexpr (i == 0) {
3051
+ // we can assume args_buffer is properly aligned to the
3052
+ // first argument
3053
+ return 0;
3054
+ } else {
3055
+ constexpr int prev_off = get_offset<i - 1>();
3056
+ constexpr int prev_past_end =
3057
+ prev_off + sizeof(arg_type<i - 1>);
3058
+ using T = arg_type<i>;
3059
+ // is the past-the-end of the i-1st element properly aligned
3060
+ // with the ith element's alignment?
3061
+ if constexpr (prev_past_end % alignof(T) == 0) {
3062
+ return prev_past_end;
3063
+ }
3064
+ // otherwise bump prev_past_end to match alignment
3065
+ else {
3066
+ return prev_past_end +
3067
+ (alignof(T) - (prev_past_end % alignof(T)));
3068
+ }
3069
+ }
3070
+ }
3071
+
3072
+ static char *get_args_buffer(void **extra) {
3073
+ if (!extra)
3074
+ return nullptr;
3075
+ for (; (std::size_t)*extra != 0; ++extra) {
3076
+ if ((std::size_t)*extra == 1) {
3077
+ return static_cast<char *>(*(extra + 1));
3078
+ }
3079
+ }
3080
+ return nullptr;
3081
+ }
3082
+
3083
+ public:
3084
+ /// If kernel_params is nonnull, then args_selector will
3085
+ /// extract arguments from kernel_params. Otherwise, it
3086
+ /// will extract them from extra.
3087
+ /// \param [in] kernel_params Array of pointers to arguments
3088
+ /// a or null pointer.
3089
+ /// \param [in] extra Array containing pointer to argument buffer.
3090
+ args_selector(void **kernel_params, void **extra)
3091
+ : kernel_params(kernel_params),
3092
+ args_buffer(get_args_buffer(extra)) {}
3093
+
3094
+ /// Get a reference to the ith argument extracted from kernel_params
3095
+ /// or extra.
3096
+ /// \param [in] i Index of argument to get
3097
+ /// \returns Reference to the ith argument
3098
+ template <int i> arg_type<i> &get() {
3099
+ if (kernel_params) {
3100
+ return *static_cast<arg_type<i> *>(kernel_params[i]);
3101
+ } else {
3102
+ return *reinterpret_cast<arg_type<i> *>(args_buffer +
3103
+ get_offset<i>());
3104
+ }
3105
+ }
3106
+ }; // COPY from DPCT head file
3107
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
3108
+
3109
+ /// Utility class for launching SYCL kernels through kernel
3110
+ /// function wrapper.
3111
+ /// For example:
3112
+ /// A SYCL kernel function:
3113
+ /// void kernel_func(int *ptr, sycl::nd_item<3> item);
3114
+ /// Kernel function wrapper:
3115
+ /// void kernel_func_wrapper(int *ptr) {
3116
+ /// sycl::queue queue = *dpct::kernel_launcher::_que;
3117
+ /// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
3118
+ /// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
3119
+ /// queue.parallel_for(
3120
+ /// nr,
3121
+ /// [=](sycl::nd_item<3> item_ct1) {
3122
+ /// kernel_func(ptr, item_ct1);
3123
+ /// });
3124
+ /// }
3125
+ /// Then launch the kernel through wrapper like:
3126
+ /// typedef void(*fpt)(int *);
3127
+ /// fpt fp = kernel_func_wrapper;
3128
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
3129
+ /// device_ptr);
3130
+ /// If the origin function type is erased, then need to register it first:
3131
+ /// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
3132
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
3133
+ /// 0, 0);
3134
+ class kernel_launcher {
3135
+ template <typename FuncT, typename ArgSelector, std::size_t... Index>
3136
+ static void launch_helper(FuncT &&func, ArgSelector &selector,
3137
+ std::index_sequence<Index...>) {
3138
+ func(selector.template get<Index>()...);
3139
+ }
3140
+ static void set_execution_config(dim3 group_range, dim3 local_range,
3141
+ unsigned int local_mem_size,
3142
+ queue_ptr que) {
3143
+ if (que) {
3144
+ _que = que;
3145
+ } else {
3146
+ _que = &get_default_queue();
3147
+ }
3148
+ _nr = sycl::nd_range<3>(
3149
+ static_cast<sycl::range<3>>(group_range * local_range),
3150
+ static_cast<sycl::range<3>>(local_range));
3151
+ _local_mem_size = local_mem_size;
3152
+
3153
+
3154
+ };
3155
+ static inline std::mutex kernel_function_ptr_map_mutex;
3156
+
3157
+ public:
3158
+ /// Variables for storing execution configuration.
3159
+ static inline thread_local sycl::queue *_que = nullptr;
3160
+ static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
3161
+ static inline thread_local unsigned int _local_mem_size = 0;
3162
+ /// Map for retrieving launchable functor from a raw pointer.
3163
+ static inline std::map<
3164
+ const void *,
3165
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
3166
+ kernel_function_ptr_map = {};
3167
+
3168
+ /// Registers a kernel function pointer with a corresponding launchable
3169
+ /// functor.
3170
+ /// \param [in] func Pointer to the kernel function.
3171
+ /// \param [in] launcher Functor to handle kernel invocation.
3172
+ static void register_kernel_ptr(
3173
+ const void *func,
3174
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
3175
+ launcher) {
3176
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
3177
+ kernel_function_ptr_map[func] = std::move(launcher);
3178
+ }
3179
+ /// Launches a kernel function with arguments provided directly through
3180
+ /// kernel function wrapper.
3181
+ /// \tparam FuncT Type of the kernel function wrapper.
3182
+ /// \tparam ArgsT Types of kernel arguments.
3183
+ /// \param [in] func Pointer to the kernel function wrapper.
3184
+ /// \param [in] group_range SYCL group range.
3185
+ /// \param [in] local_range SYCL local range.
3186
+ /// \param [in] local_mem_size The size of local memory required by the
3187
+ /// kernel function. \param [in] que SYCL queue used to execute kernel.
3188
+ /// \param [in] args Kernel arguments.
3189
+ template <typename FuncT, typename... ArgsT>
3190
+ static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
3191
+ launch(FuncT *func, dim3 group_range, dim3 local_range,
3192
+ unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
3193
+ set_execution_config(group_range, local_range, local_mem_size, que);
3194
+ func(args...);
3195
+ }
3196
+ /// Launches a kernel function through registered kernel function
3197
+ /// wrapper. \param [in] func Pointer to the registered kernel function
3198
+ /// wrapper. \param [in] group_range SYCL group range. \param [in]
3199
+ /// local_range SYCL local range. \param [in] args Array of pointers to
3200
+ /// kernel arguments. \param [in] local_mem_size The size of local
3201
+ /// memory required by the kernel function. \param [in] que SYCL queue
3202
+ /// used to execute kernel.
3203
+ static void launch(const void *func, dim3 group_range, dim3 local_range,
3204
+ void **args, unsigned int local_mem_size,
3205
+ queue_ptr que) {
3206
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
3207
+ auto Iter = kernel_function_ptr_map.find(func);
3208
+ if (Iter == kernel_function_ptr_map.end()) {
3209
+ throw std::runtime_error("dpct::launch() : no registered "
3210
+ "kernel function wrapper found.");
3211
+ }
3212
+ (Iter->second)(group_range, local_range, args, local_mem_size, que);
3213
+ }
3214
+ /// Launches a kernel function with packed arguments through kernel
3215
+ /// function wrapper.
3216
+ /// \tparam FuncT Type of the kernel function wrapper.
3217
+ /// \param [in] func Pointer to the kernel function wrapper.
3218
+ /// \param [in] group_range SYCL group range.
3219
+ /// \param [in] local_range SYCL local range.
3220
+ /// \param [in] args Array of pointers to kernel arguments.
3221
+ /// \param [in] local_mem_size The size of local memory required by the
3222
+ /// kernel function. \param [in] que SYCL queue used to execute kernel.
3223
+ template <typename FuncT>
3224
+ static std::enable_if_t<std::is_function_v<FuncT>, void>
3225
+ launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
3226
+ unsigned int local_mem_size, queue_ptr que) {
3227
+ constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
3228
+ set_execution_config(group_range, local_range, local_mem_size, que);
3229
+ args_selector<p_num, p_num, FuncT> selector(args, nullptr);
3230
+ launch_helper(func, selector, std::make_index_sequence<p_num>{});
3231
+ }
3232
+ }; // COPY from DPCT head file
3233
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
3234
+
3235
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
3236
+ template <typename T>
3237
+ T select_from_sub_group(
3238
+ sycl::sub_group g,
3239
+ T x,
3240
+ int remote_local_id,
3241
+ int logical_sub_group_size = 32) {
3242
+ unsigned int start_index = g.get_local_linear_id() /
3243
+ logical_sub_group_size *
3244
+ logical_sub_group_size;
3245
+ return sycl::select_from_group(
3246
+ g, x, start_index + remote_local_id % logical_sub_group_size);
3247
+ }
3248
+
3249
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
3250
+ template <typename T>
3251
+ void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
3252
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
3253
+ int lane = sg.get_local_linear_id();
3254
+
3255
+ int lane_group8_row = lane / 8;
3256
+ int lane_group8_col = lane % 8;
3257
+
3258
+ if (!trans) {
3259
+ // calculate the source lane
3260
+ int src_lane = 2 * lane_group8_row;
3261
+ if (lane_group8_col >= 4)
3262
+ src_lane += 1;
3263
+
3264
+ // Broadcast the address from the source lane
3265
+ auto recv_addr_uintp =
3266
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
3267
+
3268
+ // Cast the received address from uintptr_t to the type of 'm'
3269
+ auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
3270
+
3271
+ // Non-transposed load
3272
+ *m = recv_addr[lane_group8_col % 4];
3273
+ } else {
3274
+ // calculate the source lane
3275
+ int src_lane = (lane % 4) * 2;
3276
+
3277
+ // Broadcast the address from the source lane
3278
+ auto recv_addr_uintp_1 =
3279
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
3280
+ auto recv_addr_uintp_2 =
3281
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
3282
+
3283
+ // Cast the received address from uintptr_t to 'half *'
3284
+ auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
3285
+ auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
3286
+
3287
+ // Transposed load
3288
+ int index = lane / 4;
3289
+ sycl::half val0 = recv_addr_1[index];
3290
+ sycl::half val1 = recv_addr_2[index];
3291
+
3292
+ // Combine the two 16-bits into one 32-bit value
3293
+ sycl::half2 val = sycl::half2(val0, val1);
3294
+ *m = *reinterpret_cast<T*>(&val);
3295
+ }
3296
+ }
3297
+
3298
/// Loads two fragments: matrix 0 into *m1 and matrix 1 into *m2, in that
/// order, via the single-matrix ldmatrix() overload.
template <typename T>
void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
  T* const targets[] = {m1, m2};
  for (unsigned idx = 0; idx < 2; ++idx) {
    ldmatrix(addr, targets[idx], trans, idx);
  }
}
3305
+
3306
/// Loads four fragments: matrices 0, 1, 2 and 3 into *m1, *m2, *m3 and *m4
/// respectively, in that order, via the single-matrix ldmatrix() overload.
template <typename T>
void ldmatrix(
    uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
  T* const targets[] = {m1, m2, m3, m4};
  for (unsigned idx = 0; idx < 4; ++idx) {
    ldmatrix(addr, targets[idx], trans, idx);
  }
}
3318
+
3319
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
3320
+
3321
/// Helper trait naming the "pack" type mma() uses to shuffle input matrix
/// fragments (A and B) between work items of a sub-group.
/// Only the primary template is defined, so every fragment element type
/// (f16, bf16, s8, ...) is moved as a 32-bit uint32_t pack. Add a
/// specialization if some element type ever needs a different pack width.
/// \tparam [in] T The element type of the input matrix fragments.
template <typename T>
struct MMAType {
  using PackType = uint32_t;
};
3331
+
3332
+ /// Each work item of a sub-group (limited to size 32) calling this function
3333
+ /// calculates a subset fragment for the output matrix D using MAD operation
3334
+ /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
3335
+ /// types:
3336
+ /// - m8n8k4 (f32.f16.f16.f32)
3337
+ /// - m8n8k16 (s32.s8.s8.s32)
3338
+ /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
3339
+ /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
3340
+ /// - m16n8k32 (s32.s8.s8.s32)
3341
+ /// Here, m, n & k define the shapes of A, B & C matrices respectively
3342
+ /// (A = [m x k], B = [k x n], C = [m x n]).
3343
+ /// \tparam [in] M The rows of A, C & D matrices
3344
+ /// \tparam [in] N The columns of B, C, D matrices
3345
+ /// \tparam [in] K The columns & rows of A & B matrices respectively
3346
+ /// \tparam [in] ABType The type of the input matrix (A & B) fragment
3347
+ /// \tparam [in] CDType The type of the output matrix (C & D) fragment
3348
+ /// \param [out] d_mat_frag The fragment of the output matrix D to store the
3349
+ /// result of A * B + C
3350
+ /// \param [in] a_mat_frag The fragment of the input matrix A to be
3351
+ /// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
3352
+ /// the input matrix B to be multiplied with A matrix fragment \param [in]
3353
+ /// c_mat_frag The fragment of the input matrix C to be added with the
3354
+ /// result of A * B fragments
3355
+ template <int M, int N, int K, typename ABType, typename CDType>
3356
+ void mma(
3357
+ volatile void** d_mat_frag,
3358
+ void* a_mat_frag,
3359
+ void* b_mat_frag,
3360
+ void* c_mat_frag) {
3361
+ auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
3362
+ auto a =
3363
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
3364
+ auto b =
3365
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
3366
+ auto c = reinterpret_cast<CDType*>(c_mat_frag);
3367
+
3368
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
3369
+ int lane = sg.get_local_linear_id();
3370
+
3371
+ static_assert(
3372
+ (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
3373
+ (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
3374
+ (M == 16 && N == 8 && K == 32),
3375
+ "Unsupported MMA shape!");
3376
+
3377
+ short row_load_offset = 4 * (lane >> 2);
3378
+ short col_load_offset = 8 * (lane % 4);
3379
+
3380
+ if constexpr (M == 8 && N == 8 && K == 4) {
3381
+ if constexpr (std::is_floating_point_v<CDType>) {
3382
+ col_load_offset = row_load_offset % 16;
3383
+
3384
+ // Init D matrix with fragments of C matrix
3385
+ *d[0] = c[0];
3386
+ *d[1] = c[1];
3387
+ *d[2] = c[2];
3388
+ *d[3] = c[3];
3389
+ *d[4] = c[4];
3390
+ *d[5] = c[5];
3391
+ *d[6] = c[6];
3392
+ *d[7] = c[7];
3393
+
3394
+ // Calculate the row and col offset indices to iterate through the row
3395
+ // & col fragments of A & B matrices
3396
+ int r_ind = (lane % 2) ? 1 : 0;
3397
+ int c_ind = ((lane % 4) / 2) ? 2 : 0;
3398
+
3399
+ // Each sub-group is responsible for computing a fragment size of 8*8
3400
+ // elements of matrix D for each of 4 MMA computations.
3401
+ // Each work item computes 8 elements of matrix D by gathering
3402
+ // their corresponding col & row matrix fragments of length k (4)
3403
+ // from A & B matrices respectively using below mapping logic:
3404
+ // row0 = (i % 4) if (lane < 16) else (i % 4) + 4
3405
+ // col0 = (lane % 4)
3406
+ // As each row & col fragment of A & B matrices is distributed across
3407
+ // 4 work items, each iteration of below loop loads a partial fragment
3408
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3409
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3410
+
3411
+ for (int i = 0; i < 4; i++) {
3412
+ // Load partial fragment from col0 of matrix A ({a0, a1})
3413
+ recv_a[0] =
3414
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3415
+ // Load partial fragment from col0 of matrix A ({a2, a3})
3416
+ recv_a[1] =
3417
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3418
+
3419
+ // Load partial fragment from row0 of matrix B ({b0, b1})
3420
+ recv_b[0] =
3421
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3422
+ // Load partial fragment from row0 of matrix B ({b2, b3})
3423
+ recv_b[1] =
3424
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3425
+
3426
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3427
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3428
+
3429
+ // Each work item calculates a partial product of A & B matrix
3430
+ // fragments and adds it to the corresponding D matrix fragment (for
3431
+ // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
3432
+ // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
3433
+ // * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
3434
+ // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
3435
+ // d3 += col1{ a3 } * row0{ b3 }
3436
+ *d[0] +=
3437
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
3438
+ *d[1] += static_cast<float>(ra[r_ind]) *
3439
+ static_cast<float>(rb[c_ind + 1]);
3440
+ *d[2] += static_cast<float>(ra[r_ind + 2]) *
3441
+ static_cast<float>(rb[c_ind]);
3442
+ *d[3] += static_cast<float>(ra[r_ind + 2]) *
3443
+ static_cast<float>(rb[c_ind + 1]);
3444
+
3445
+ // Load partial fragment from row1 of matrix B ({b0, b1})
3446
+ recv_b[0] =
3447
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
3448
+ // Load partial fragment from row1 of matrix B ({b2, b3})
3449
+ recv_b[1] =
3450
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
3451
+
3452
+ // (for even work item indices)
3453
+ // d0 += col0{ a0 } * row1{ b0 }
3454
+ // d1 += col0{ a0 } * row1{ b1 }
3455
+ // d2 += col1{ a2 } * row1{ b0 }
3456
+ // d3 += col1{ a2 } * row1{ b1 }
3457
+ // (for odd work item indices)
3458
+ // d0 += col0{ a1 } * row1{ b2 }
3459
+ // d1 += col0{ a1 } * row1{ b3 }
3460
+ // d2 += col1{ a3 } * row1{ b2 }
3461
+ // d3 += col1{ a3 } * row1{ b3 }
3462
+ *d[4] +=
3463
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
3464
+ *d[5] += static_cast<float>(ra[r_ind]) *
3465
+ static_cast<float>(rb[c_ind + 1]);
3466
+ *d[6] += static_cast<float>(ra[r_ind + 2]) *
3467
+ static_cast<float>(rb[c_ind]);
3468
+ *d[7] += static_cast<float>(ra[r_ind + 2]) *
3469
+ static_cast<float>(rb[c_ind + 1]);
3470
+ }
3471
+ }
3472
+ } else if constexpr (M == 8 && N == 8 && K == 16) {
3473
+ if constexpr (std::is_integral_v<ABType>) {
3474
+ // Init D matrix with fragments of C matrix
3475
+ *d[0] = c[0];
3476
+ *d[1] = c[1];
3477
+
3478
+ // Each sub-group is responsible for computing a fragment size of 16*8
3479
+ // elements of matrix D.
3480
+ // Each work item computes 2 elements of matrix D by gathering
3481
+ // their corresponding row & col matrix fragments of length k (16)
3482
+ // from A & B matrices respectively using below mapping logic:
3483
+ // row0 = ((lane % 4) * 4) + i
3484
+ // col0 = (lane >> 2)
3485
+ // As each row & col fragment of A & B matrices is distributed across
3486
+ // 4 work items, each iteration of below loop loads a partial fragment
3487
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3488
+ for (int i = 0; i < 4; i++) {
3489
+ typename MMAType<ABType>::PackType recv_a, recv_b[2];
3490
+
3491
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3492
+ recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3493
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3494
+ recv_b[0] =
3495
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3496
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
3497
+ recv_b[1] =
3498
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3499
+
3500
+ auto a = reinterpret_cast<ABType*>(&recv_a);
3501
+ auto b = reinterpret_cast<ABType*>(recv_b);
3502
+
3503
+ // Each work item calculates a partial product of A & B matrix
3504
+ // fragments and adds it to the corresponding D matrix fragment d0
3505
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3506
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,
3507
+ // a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *
3508
+ // col1{ b0, b1, b2, b3 }
3509
+ for (int j = 0; j < 4; j++) {
3510
+ *d[0] += a[j] * b[j];
3511
+ *d[1] += a[j] * b[j + 4];
3512
+ }
3513
+ }
3514
+ }
3515
+ } else if constexpr (M == 16 && N == 8 && K == 8) {
3516
+ if constexpr (std::is_floating_point_v<CDType>) {
3517
+ // Init D matrix fragment with C matrix fragment
3518
+ *d[0] = c[0];
3519
+ *d[1] = c[1];
3520
+ *d[2] = c[2];
3521
+ *d[3] = c[3];
3522
+
3523
+ // Each sub-group is responsible for computing a fragment size of 16*8
3524
+ // elements of matrix D.
3525
+ // Each work item computes 4 elements of matrix D by gathering
3526
+ // their corresponding row & col matrix fragments of length k (8)
3527
+ // from A & B matrices respectively using below mapping logic:
3528
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3529
+ // col0 = (lane % 4) * 2 + (i & 0x1)
3530
+ // As each row & col fragment of A & B matrices is distributed across
3531
+ // 4 work items, each iteration of below loop loads a partial fragment
3532
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3533
+ for (int i = 0; i < 4; i++) {
3534
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3535
+
3536
+ // Load partial fragment from row0 of matrix A ({a0, a1})
3537
+ recv_a[0] =
3538
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3539
+ // Load partial fragment from row1 of matrix A ({a2, a3})
3540
+ recv_a[1] =
3541
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3542
+ // Load partial fragment from col0 of matrix B ({b0, b1})
3543
+ recv_b[0] =
3544
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3545
+ // Load partial fragment from col1 of matrix B ({b0, b1})
3546
+ recv_b[1] =
3547
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3548
+
3549
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3550
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3551
+
3552
+ // Each work item calculates a partial product of A & B matrix
3553
+ // fragments and adds it to the corresponding D matrix fragment d0
3554
+ // += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{
3555
+ // b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3
3556
+ // } * col1{ b0, b1 }
3557
+ for (int j = 0; j < 2; j++) {
3558
+ *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
3559
+ *d[1] +=
3560
+ static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
3561
+ *d[2] +=
3562
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
3563
+ *d[3] +=
3564
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
3565
+ }
3566
+ }
3567
+ }
3568
+ } else if constexpr (M == 16 && N == 8 && K == 16) {
3569
+ if constexpr (std::is_floating_point_v<CDType>) {
3570
+ // Init D matrix fragment with C matrix fragment
3571
+ *d[0] = c[0];
3572
+ *d[1] = c[1];
3573
+ *d[2] = c[2];
3574
+ *d[3] = c[3];
3575
+
3576
+ // Each sub-group is responsible for computing a fragment size of 16*8
3577
+ // elements of matrix D.
3578
+ // Each work item computes 4 elements of matrix D by gathering
3579
+ // their corresponding row & col matrix fragments of length k (8)
3580
+ // from A & B matrices respectively using below mapping logic:
3581
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3582
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
3583
+ // As each row & col fragment of A & B matrices is distributed across
3584
+ // 4 work items, each iteration of below loop loads a partial fragment
3585
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3586
+ for (int i = 0; i < 4; i++) {
3587
+ typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
3588
+
3589
+ // Load partial fragment from row0 of matrix A ({a0, a1})
3590
+ recv_a[0] =
3591
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3592
+ // Load partial fragment from row0 of matrix A ({a2, a3})
3593
+ recv_a[1] =
3594
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
3595
+ // Load partial fragment from row1 of matrix A ({a0, a1})
3596
+ recv_a[2] =
3597
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3598
+ // Load partial fragment from row1 of matrix A ({a2, a3})
3599
+ recv_a[3] =
3600
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
3601
+
3602
+ // Load partial fragment from col0 of matrix B ({b0, b1})
3603
+ recv_b[0] =
3604
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3605
+ // Load partial fragment from col0 of matrix B ({b2, b3})
3606
+ recv_b[1] =
3607
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3608
+ // Load partial fragment from col1 of matrix B ({b0, b1})
3609
+ recv_b[2] =
3610
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
3611
+ // Load partial fragment from col1 of matrix B ({b2, b3})
3612
+ recv_b[3] =
3613
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
3614
+
3615
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3616
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3617
+
3618
+ // Each work item calculates a partial product of A & B matrix
3619
+ // fragments and adds it to the corresponding D matrix fragment d0
3620
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3621
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,
3622
+ // a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
3623
+ // col1{ b0, b1, b2, b3 }
3624
+ for (int j = 0; j < 4; j++) {
3625
+ *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
3626
+ *d[1] +=
3627
+ static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
3628
+ *d[2] +=
3629
+ static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
3630
+ *d[3] += static_cast<CDType>(ra[j + 4]) *
3631
+ static_cast<CDType>(rb[j + 4]);
3632
+ }
3633
+ }
3634
+ } else if constexpr (std::is_integral_v<ABType>) {
3635
+ // Init D matrix with fragments of C matrix
3636
+ *d[0] = c[0];
3637
+ *d[1] = c[1];
3638
+ *d[2] = c[2];
3639
+ *d[3] = c[3];
3640
+
3641
+ // Each sub-group is responsible for computing a fragment size of 16*8
3642
+ // elements of matrix D.
3643
+ // Each work item computes 4 elements of matrix D by gathering
3644
+ // their corresponding row & col matrix fragments of length k (8)
3645
+ // from A & B matrices respectively using below mapping logic:
3646
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3647
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
3648
+ // As each row & col fragment of A & B matrices is distributed across
3649
+ // 4 work items, each iteration of below loop loads a partial fragment
3650
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
3651
+ for (int i = 0; i < 4; i++) {
3652
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3653
+
3654
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3655
+ recv_a[0] =
3656
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3657
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
3658
+ recv_a[1] =
3659
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3660
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3661
+ recv_b[0] =
3662
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3663
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
3664
+ recv_b[1] =
3665
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3666
+
3667
+ auto ra = reinterpret_cast<ABType*>(recv_a);
3668
+ auto rb = reinterpret_cast<ABType*>(recv_b);
3669
+
3670
+ // Each work item calculates a partial product of A & B matrix
3671
+ // fragments and adds it to the corresponding D matrix fragment d0
3672
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3673
+ // a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,
3674
+ // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
3675
+ // col1{ b4, b5, b6, b7 }
3676
+ for (int i = 0; i < 4; i++) {
3677
+ *d[0] += ra[i] * rb[i];
3678
+ *d[1] += ra[i] * rb[i + 4];
3679
+ *d[2] += ra[i + 4] * rb[i];
3680
+ *d[3] += ra[i + 4] * rb[i + 4];
3681
+ }
3682
+ }
3683
+ }
3684
+ } else if constexpr (M == 16 && N == 8 && K == 32) {
3685
+ if constexpr (std::is_integral_v<ABType>) {
3686
+ // Init D matrix with fragments of C matrix
3687
+ *d[0] = c[0];
3688
+ *d[1] = c[1];
3689
+ *d[2] = c[2];
3690
+ *d[3] = c[3];
3691
+
3692
+ // Each sub-group is responsible for computing a fragment size of 16*8
3693
+ // elements of matrix D.
3694
+ // Each work item computes 4 elements of matrix D by gathering
3695
+ // their corresponding row & col matrix fragments of length k (32)
3696
+ // from A & B matrices respectively using below mapping logic:
3697
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
3698
+ // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i
3699
+ // & 0x3) As each row & col fragment of A & B matrices is distributed
3700
+ // across 4 work items, each iteration of below loop loads a partial
3701
+ // fragment of matrix A (row) and matrix B (col) using the row & col
3702
+ // offsets.
3703
+ for (int i = 0; i < 4; i++) {
3704
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3705
+
3706
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
3707
+ recv_a[0] =
3708
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
3709
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
3710
+ recv_a[1] =
3711
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
3712
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
3713
+ recv_b[0] =
3714
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
3715
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
3716
+ recv_b[1] =
3717
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
3718
+
3719
+ auto a = reinterpret_cast<ABType*>(recv_a);
3720
+ auto b = reinterpret_cast<ABType*>(recv_b);
3721
+
3722
+ // Each work item calculates a partial product of A & B matrix
3723
+ // fragments and adds it to the corresponding D matrix fragment d0
3724
+ // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
3725
+ // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,
3726
+ // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
3727
+ // col1{ b0, b1, b2, b3 }
3728
+ for (int j = 0; j < 4; j++) {
3729
+ *d[0] += a[j] * b[j];
3730
+ *d[1] += a[j] * b[j + 4];
3731
+ *d[2] += a[j + 4] * b[j];
3732
+ *d[3] += a[j + 4] * b[j + 4];
3733
+ }
3734
+ }
3735
+
3736
+ for (int i = 0; i < 4; i++) {
3737
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
3738
+
3739
+ // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
3740
+ recv_a[0] =
3741
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
3742
+ // Load partial fragment from row1 of matrix A ({a12, a13, a14,
3743
+ // a15})
3744
+ recv_a[1] =
3745
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
3746
+ // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
3747
+ recv_b[0] =
3748
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
3749
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
3750
+ recv_b[1] =
3751
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
3752
+
3753
+ auto a = reinterpret_cast<ABType*>(recv_a);
3754
+ auto b = reinterpret_cast<ABType*>(recv_b);
3755
+
3756
+ // Each work item calculates a partial product of A & B matrix
3757
+ // fragments and adds it to the corresponding D matrix fragment d0
3758
+ // += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{
3759
+ // a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,
3760
+ // a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,
3761
+ // a15 } * col1{ b4, b5, b6, b7 }
3762
+ for (int j = 0; j < 4; j++) {
3763
+ *d[0] += a[j] * b[j];
3764
+ *d[1] += a[j] * b[j + 4];
3765
+ *d[2] += a[j + 4] * b[j];
3766
+ *d[3] += a[j + 4] * b[j + 4];
3767
+ }
3768
+ }
3769
+ }
3770
+ }
3771
+ }
3028
3772
  } // COPY from DPCT head files
3029
3773
 
3030
3774
  #endif // GGML_SYCL_DPCT_HELPER_HPP