whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -151,72 +151,50 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
151
151
  }
152
152
 
153
153
  template<typename T>
154
- static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
155
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
156
- dst[i] = op_sgn(x[i]);
157
- }
154
+ static __dpct_inline__ T op_floor(T x) {
155
+ return sycl::floor(x);
158
156
  }
159
157
 
160
158
  template<typename T>
161
- static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
162
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
163
- dst[i] = op_abs(x[i]);
164
- }
159
+ static __dpct_inline__ T op_ceil(T x) {
160
+ return sycl::ceil(x);
165
161
  }
166
162
 
167
163
  template<typename T>
168
- static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
169
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
170
- dst[i] = op_elu(x[i]);
171
- }
164
+ static __dpct_inline__ T op_round(T x) {
165
+ return sycl::round(x);
172
166
  }
173
167
 
174
168
  template<typename T>
175
- static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
169
+ static __dpct_inline__ T op_trunc(T x) {
170
+ return sycl::trunc(x);
171
+ }
172
+
173
+ template<typename T, typename F>
174
+ static void unary_op_generic_kernel(
175
+ const T * x,
176
+ T * dst,
177
+ const int k,
178
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
179
+ const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3,
180
+ const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,
181
+ const sycl::nd_item<1> & item_ct1,
182
+ F func) {
183
+
184
+ (void) ne3;
176
185
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
177
- dst[i] = op_gelu(x[i]);
178
- }
179
- }
186
+ const int64_t i0 = i % ne0;
187
+ const int64_t i1 = (i / ne0) % ne1;
188
+ const int64_t i2 = (i / (ne0*ne1)) % ne2;
189
+ const int64_t i3 = i / (ne0*ne1*ne2);
180
190
 
181
- template<typename T>
182
- static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
183
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
184
- dst[i] = op_silu(x[i]);
185
- }
186
- }
191
+ const char * src_base = (const char *) x;
192
+ char * dst_base = (char *) dst;
187
193
 
188
- template<typename T>
189
- static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
190
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
191
- dst[i] = op_gelu_quick(x[i]);
192
- }
193
- }
194
+ const T * srcp = (const T *)(src_base + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3 );
195
+ T * dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);
194
196
 
195
- template<typename T>
196
- static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
197
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
198
- dst[i] = op_gelu_erf(x[i]);
199
- }
200
- }
201
-
202
- template<typename T>
203
- static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
204
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
205
- dst[i] = op_tanh(x[i]);
206
- }
207
- }
208
-
209
- template<typename T>
210
- static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
211
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
212
- dst[i] = op_relu(x[i]);
213
- }
214
- }
215
-
216
- template<typename T>
217
- static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
218
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
219
- dst[i] = op_sigmoid(x[i]);
197
+ *dstp = func(*srcp);
220
198
  }
221
199
  }
222
200
 
@@ -242,65 +220,59 @@ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::n
242
220
  }
243
221
 
244
222
  template<typename T>
245
- static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
223
+ static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
246
224
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
247
- dst[i] = op_hardsigmoid(x[i]);
225
+ dst[i] = op_log(x[i]);
248
226
  }
249
227
  }
250
228
 
251
- template<typename T>
252
- static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
253
- SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
254
- dst[i] = op_hardswish(x[i]);
255
- }
256
- }
257
229
 
258
230
  template<typename T>
259
- static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
231
+ static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
260
232
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
261
- dst[i] = op_exp(x[i]);
233
+ dst[i] = op_leaky_relu(x[i], negative_slope);
262
234
  }
263
235
  }
264
236
 
265
237
  template<typename T>
266
- static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
238
+ static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
267
239
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
268
- dst[i] = op_log(x[i]);
240
+ dst[i] = op_sqr(x[i]);
269
241
  }
270
242
  }
271
243
 
272
244
  template<typename T>
273
- static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
245
+ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
274
246
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
275
- dst[i] = op_neg(x[i]);
247
+ dst[i] = op_clamp(x[i], min_val, max_val);
276
248
  }
277
249
  }
278
250
 
279
251
  template<typename T>
280
- static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
252
+ static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
281
253
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
282
- dst[i] = op_step(x[i]);
254
+ dst[i] = op_floor(x[i]);
283
255
  }
284
256
  }
285
257
 
286
258
  template<typename T>
287
- static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
259
+ static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
288
260
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
289
- dst[i] = op_leaky_relu(x[i], negative_slope);
261
+ dst[i] = op_ceil(x[i]);
290
262
  }
291
263
  }
292
264
 
293
265
  template<typename T>
294
- static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
266
+ static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
295
267
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
296
- dst[i] = op_sqr(x[i]);
268
+ dst[i] = op_round(x[i]);
297
269
  }
298
270
  }
299
271
 
300
272
  template<typename T>
301
- static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
273
+ static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
302
274
  SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
303
- dst[i] = op_clamp(x[i], min_val, max_val);
275
+ dst[i] = op_trunc(x[i]);
304
276
  }
305
277
  }
306
278
 
@@ -328,26 +300,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
328
300
  dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
329
301
  }
330
302
 
331
- template <typename T>
332
- static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
333
- const sycl::nd_item<3> &item_ct1) {
334
- int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
335
- if (nidx >= ne0) {
336
- return;
337
- }
338
-
339
- // operation
340
- int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
341
- item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
342
- if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
343
- int offset_src = nidx + item_ct1.get_group(1) * ne00 +
344
- item_ct1.get_group(0) * ne00 * ne01;
345
- dst[offset_dst] = x[offset_src];
346
- } else {
347
- dst[offset_dst] = static_cast<T>(0.0f);
348
- }
349
- }
350
-
351
303
  template<typename T>
352
304
  static void clamp(const T * x, T * dst, const float min, const float max, const int k,
353
305
  const sycl::nd_item<1> &item_ct1) {
@@ -417,6 +369,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
417
369
  });
418
370
  }
419
371
 
372
+ template<typename T>
373
+ static void arange_kernel(T * dst, const int k, T start, T step,
374
+ const sycl::nd_item<1> &item_ct1) {
375
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
376
+ dst[i] = start + static_cast<T>(i) * step;
377
+ }
378
+ }
379
+
420
380
  template<typename T>
421
381
  static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
422
382
  const int nb02, const int nb03, const int ne10, const int ne11,
@@ -431,18 +391,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
431
391
  });
432
392
  }
433
393
 
434
- template<typename T>
435
- static void pad_sycl(const T *x, T *dst, const int ne00,
436
- const int ne01, const int ne02, const int ne0,
437
- const int ne1, const int ne2, queue_ptr stream) {
438
- int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
439
- sycl::range<3> gridDim(ne2, ne1, num_blocks);
440
- stream->parallel_for(
441
- sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
442
- sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
443
- [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
444
- }
445
-
446
394
  template<typename KernelInvoker, typename... Args>
447
395
  static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
448
396
  #if defined (GGML_SYCL_F16)
@@ -596,199 +544,142 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
596
544
  }
597
545
  }
598
546
 
599
- template<typename KernelInvoker, typename... Args>
600
- static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
601
- #if defined (GGML_SYCL_F16)
602
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
603
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
604
- #else
605
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
606
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
607
- #endif
608
- GGML_ASSERT(dst->src[0]->type == dst->type);
609
- GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
610
- dpct::queue_ptr main_stream = ctx.stream();
611
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
612
- switch (dst->type) {
613
- #if defined (GGML_SYCL_F16)
614
- case GGML_TYPE_F16:
615
- {
616
- auto data_pts = cast_data<sycl::half>(dst);
617
- kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
618
- (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
619
- break;
620
- }
621
- #endif
622
- case GGML_TYPE_F32:
623
- {
624
- auto data_pts = cast_data<float>(dst);
625
- kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
626
- (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
627
- break;
628
- }
629
- default:
630
- GGML_ABORT("GGML tensor type not supported!\n");
631
- }
632
- }
547
+ template<typename F>
548
+ static inline void ggml_sycl_op_unary(
549
+ ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
633
550
 
634
- } // namespace ggml_sycl_detail
551
+ ggml_tensor * src0 = dst->src[0];
635
552
 
553
+ const int64_t ne0 = dst->ne[0];
554
+ const int64_t ne1 = dst->ne[1];
555
+ const int64_t ne2 = dst->ne[2];
556
+ const int64_t ne3 = dst->ne[3];
636
557
 
558
+ const size_t nb0 = src0->nb[0];
559
+ const size_t nb1 = src0->nb[1];
560
+ const size_t nb2 = src0->nb[2];
561
+ const size_t nb3 = src0->nb[3];
562
+
563
+ const size_t nbd0 = dst->nb[0];
564
+ const size_t nbd1 = dst->nb[1];
565
+ const size_t nbd2 = dst->nb[2];
566
+ const size_t nbd3 = dst->nb[3];
637
567
 
638
- static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
639
568
  ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
640
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
569
+ [=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
570
+
641
571
  const int num_blocks = ceil_div(k_elements, 256);
572
+
642
573
  stream->parallel_for(
643
574
  sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
644
575
  sycl::range<1>(256)),
645
576
  [=](sycl::nd_item<1> item_ct1) {
646
- unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
577
+ unary_op_generic_kernel(
578
+ src, dst_ptr, k_elements,
579
+ ne0, ne1, ne2, ne3,
580
+ nb0, nb1, nb2, nb3,
581
+ nbd0, nbd1, nbd2, nbd3,
582
+ item_ct1,
583
+ func
584
+ );
647
585
  });
648
586
  });
649
587
  }
650
588
 
651
- static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
652
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
653
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
654
- const int num_blocks = ceil_div(k_elements, 256);
655
- stream->parallel_for(
656
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
657
- sycl::range<1>(256)),
658
- [=](sycl::nd_item<1> item_ct1) {
659
- unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
660
- });
589
+
590
+ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
591
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
592
+ float start, stop, step;
593
+ memcpy(&start, dst->op_params, sizeof(float));
594
+ memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
595
+ memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
596
+ dpct::queue_ptr stream = ctx.stream();
597
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
598
+ float * dst_ptr = (float *)dst->data;
599
+ const int k = (int)ggml_nelements(dst);
600
+ const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
601
+ stream->parallel_for(
602
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
603
+ sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
604
+ [=](sycl::nd_item<1> item_ct1) {
605
+ arange_kernel(dst_ptr, k, start, step, item_ct1);
661
606
  });
662
607
  }
663
608
 
664
- static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
665
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
666
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
667
- const int num_blocks = ceil_div(k_elements, 256);
668
- stream->parallel_for(
669
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
670
- sycl::range<1>(256)),
671
- [=](sycl::nd_item<1> item_ct1) {
672
- unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
673
- });
674
- });
609
+ } // namespace ggml_sycl_detail
610
+
611
+
612
+
613
+ static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
614
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
615
+ return op_sgn(x);
616
+ });
617
+ }
618
+
619
+
620
+ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
621
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
622
+ return op_abs(x);
623
+ });
675
624
  }
676
625
 
626
+ static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
627
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
628
+ return op_elu(x);
629
+ });
630
+ }
677
631
  static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
678
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
679
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
680
- const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
681
- stream->parallel_for(
682
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
683
- sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
684
- [=](sycl::nd_item<1> item_ct1) {
685
- unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
686
- });
687
- });
632
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
633
+ return op_silu(x);
634
+ });
688
635
  }
689
636
 
690
637
  static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
691
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
692
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
693
- const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
694
- stream->parallel_for(
695
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
696
- sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
697
- [=](sycl::nd_item<1> item_ct1) {
698
- unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
699
- });
700
- });
638
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
639
+ return op_gelu(x);
640
+ });
701
641
  }
702
642
 
703
- static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
704
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
705
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
706
- const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
707
- stream->parallel_for(
708
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
709
- sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
710
- [=](sycl::nd_item<1> item_ct1) {
711
- unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
712
- });
713
- });
643
+ static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
644
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
645
+ return op_gelu_quick(x);
646
+ });
714
647
  }
715
648
 
716
- static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
717
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
718
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
719
- const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
720
- stream->parallel_for(
721
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
722
- sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
723
- [=](sycl::nd_item<1> item_ct1) {
724
- unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
725
- });
726
- });
649
+ static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
650
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
651
+ return op_gelu_erf(x);
652
+ });
727
653
  }
728
654
 
729
655
  static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
730
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
731
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
732
- const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
733
- stream->parallel_for(
734
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
735
- sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
736
- [=](sycl::nd_item<1> item_ct1) {
737
- unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
738
- });
739
- });
656
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
657
+ return op_tanh(x);
658
+ });
740
659
  }
741
660
 
742
661
  static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
743
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
744
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
745
- const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
746
- stream->parallel_for(
747
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
748
- sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
749
- [=](sycl::nd_item<1> item_ct1) {
750
- unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
751
- });
752
- });
662
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
663
+ return op_relu(x);
664
+ });
753
665
  }
754
666
 
755
667
  static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
756
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
757
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
758
- const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
759
- stream->parallel_for(
760
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
761
- sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
762
- [=](sycl::nd_item<1> item_ct1) {
763
- unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
764
- });
765
- });
668
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
669
+ return op_hardsigmoid(x);
670
+ });
766
671
  }
767
672
 
768
673
  static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
769
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
770
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
771
- const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
772
- stream->parallel_for(
773
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
774
- sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
775
- [=](sycl::nd_item<1> item_ct1) {
776
- unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
777
- });
778
- });
674
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
675
+ return op_hardswish(x);
676
+ });
779
677
  }
780
678
 
781
679
  static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
782
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
783
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
784
- const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
785
- stream->parallel_for(
786
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
787
- sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
788
- [=](sycl::nd_item<1> item_ct1) {
789
- unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
790
- });
791
- });
680
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
681
+ return op_exp(x);
682
+ });
792
683
  }
793
684
 
794
685
  static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -805,42 +696,22 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
805
696
  }
806
697
 
807
698
  static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
808
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
809
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
810
- const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
811
- stream->parallel_for(
812
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
813
- sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
814
- [=](sycl::nd_item<1> item_ct1) {
815
- unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
816
- });
817
- });
699
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
700
+ return op_neg(x);
701
+ });
818
702
  }
819
703
 
704
+
820
705
  static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
821
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
822
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
823
- const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
824
- stream->parallel_for(
825
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
826
- sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
827
- [=](sycl::nd_item<1> item_ct1) {
828
- unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
829
- });
830
- });
706
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
707
+ return op_step(x);
708
+ });
831
709
  }
832
710
 
833
711
  static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
834
- ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
835
- [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
836
- const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
837
- stream->parallel_for(
838
- sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
839
- sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
840
- [=](sycl::nd_item<1> item_ct1) {
841
- unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
842
- });
843
- });
712
+ ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
713
+ return op_sigmoid(x);
714
+ });
844
715
  }
845
716
 
846
717
  static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -919,14 +790,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
919
790
  });
920
791
  }
921
792
 
922
- static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
923
- ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
924
- [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
925
- queue_ptr stream) {
926
- ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
927
- });
928
- }
929
-
930
793
  static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
931
794
  float min_val;
932
795
  float max_val;
@@ -944,6 +807,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
944
807
  }, min_val, max_val);
945
808
  }
946
809
 
810
+ static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
811
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
812
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
813
+ const int num_blocks = ceil_div(k_elements, 256);
814
+ stream->parallel_for(
815
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
816
+ sycl::range<1>(256)),
817
+ [=](sycl::nd_item<1> item_ct1) {
818
+ unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
819
+ });
820
+ });
821
+ }
822
+
823
+ static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
824
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
825
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
826
+ const int num_blocks = ceil_div(k_elements, 256);
827
+ stream->parallel_for(
828
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
829
+ sycl::range<1>(256)),
830
+ [=](sycl::nd_item<1> item_ct1) {
831
+ unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
832
+ });
833
+ });
834
+ }
835
+
836
+ static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
837
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
838
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
839
+ const int num_blocks = ceil_div(k_elements, 256);
840
+ stream->parallel_for(
841
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
842
+ sycl::range<1>(256)),
843
+ [=](sycl::nd_item<1> item_ct1) {
844
+ unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
845
+ });
846
+ });
847
+ }
848
+
849
+ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
850
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
851
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
852
+ const int num_blocks = ceil_div(k_elements, 256);
853
+ stream->parallel_for(
854
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
855
+ sycl::range<1>(256)),
856
+ [=](sycl::nd_item<1> item_ct1) {
857
+ unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
858
+ });
859
+ });
860
+ }
861
+
947
862
  static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
948
863
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
949
864
  GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -996,6 +911,98 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
996
911
  });
997
912
  }
998
913
 
914
+ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
915
+ x = sycl::fmin(x, limit);
916
+ g = sycl::fmax(sycl::fmin(g, limit), -limit);
917
+
918
+ float out_glu = x / (1.0f + sycl::native::exp(-x * alpha));
919
+ out_glu = out_glu * (1.0f + g);
920
+ return out_glu;
921
+ }
922
+
923
+
924
+ template <typename T>
925
+ static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
926
+ const int64_t n, const int64_t o0, const int64_t o1,
927
+ float alpha, float limit, sycl::nd_item<3> item_ct1) {
928
+ const int64_t i = int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
929
+
930
+ if (i >= k) {
931
+ return;
932
+ }
933
+
934
+ const int64_t j0 = (i / n) * o0 + (i % n);
935
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
936
+
937
+ float xi = x[j0];
938
+ float gi = g[j1];
939
+
940
+ dst[i] = ggml_sycl_op_swiglu_oai_single(xi, gi, alpha, limit);
941
+ }
942
+
943
+ template <typename T>
944
+ static void swiglu_oai_sycl(const T * x,
945
+ const T * g,
946
+ T * dst,
947
+ const int64_t k,
948
+ const int64_t n,
949
+ const int64_t o0,
950
+ const int64_t o1,
951
+ const float alpha,
952
+ const float limit,
953
+ dpct::queue_ptr stream) {
954
+ const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
955
+ stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
956
+ sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
957
+ [=](sycl::nd_item<3> item_ct1) {
958
+ swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
959
+ });
960
+ }
961
+
962
+ void ggml_sycl_op_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
963
+ const ggml_tensor * src0 = dst->src[0];
964
+ const ggml_tensor * src1 = dst->src[1];
965
+ void * src0_d = src0->data;
966
+ void * src1_d = src1 ? src1->data : src0->data;
967
+ const int64_t src0_o = src0->nb[1];
968
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
969
+ void * dst_d = dst->data;
970
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
971
+ dpct::queue_ptr stream = ctx.stream();
972
+
973
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
974
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
975
+ GGML_ASSERT(ggml_is_contiguous(dst));
976
+
977
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
978
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
979
+ GGML_ASSERT(src0->type == dst->type);
980
+ GGML_ASSERT(dst->ne[0] == nc);
981
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
982
+
983
+ if (src1) {
984
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
985
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
986
+ GGML_ASSERT(src1->ne[0] == nc);
987
+ GGML_ASSERT(src0->type == src1->type);
988
+ }
989
+
990
+ //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
991
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
992
+ const float alpha = ggml_get_op_params_f32(dst, 2);
993
+ const float limit = ggml_get_op_params_f32(dst, 3);
994
+
995
+ float * src0_p = (float *) src0_d;
996
+ float * src1_p = (float *) src1_d;
997
+
998
+ if (!src1) {
999
+ src0_p += swapped ? nc : 0;
1000
+ src1_p += swapped ? 0 : nc;
1001
+ }
1002
+
1003
+ swiglu_oai_sycl(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
1004
+ }
1005
+
999
1006
  static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1000
1007
  ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
1001
1008
  [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
@@ -1119,10 +1126,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1119
1126
  ggml_sycl_op_upscale(ctx, dst);
1120
1127
  }
1121
1128
 
1122
- void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1123
- scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1124
- ggml_sycl_op_pad(ctx, dst);
1125
- }
1126
1129
 
1127
1130
  void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1128
1131
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1159,6 +1162,11 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1159
1162
  ggml_sycl_op_swiglu(ctx, dst);
1160
1163
  }
1161
1164
 
1165
+ void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1166
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1167
+ ggml_sycl_op_swiglu_oai(ctx, dst);
1168
+ }
1169
+
1162
1170
  void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1163
1171
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1164
1172
  ggml_sycl_op_geglu_erf(ctx, dst);
@@ -1168,3 +1176,28 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1168
1176
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1169
1177
  ggml_sycl_op_geglu_quick(ctx, dst);
1170
1178
  }
1179
+
1180
+ void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1181
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
1182
+ ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
1183
+ }
1184
+
1185
+ void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1186
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1187
+ ggml_sycl_op_floor(ctx, dst);
1188
+ }
1189
+
1190
+ void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1191
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1192
+ ggml_sycl_op_ceil(ctx, dst);
1193
+ }
1194
+
1195
+ void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1196
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1197
+ ggml_sycl_op_round(ctx, dst);
1198
+ }
1199
+
1200
+ void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1201
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1202
+ ggml_sycl_op_trunc(ctx, dst);
1203
+ }