whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -1,12 +1,19 @@
1
1
  #include "common.hpp"
2
+ #include "ggml-sycl/presets.hpp"
2
3
  #include "ggml.h"
3
4
  #include "element_wise.hpp"
4
5
 
6
+ #define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
7
+ for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
8
+
9
+ #define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
10
+ (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
11
+
12
+
5
13
  static void acc_f32(const float * x, const float * y, float * dst, const int ne,
6
14
  const int ne10, const int ne11, const int ne12,
7
- const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
8
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
9
- item_ct1.get_local_id(2);
15
+ const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
16
+ const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
10
17
  if (i >= ne) {
11
18
  return;
12
19
  }
@@ -21,239 +28,280 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
21
28
  }
22
29
  }
23
30
 
31
+ /* Unary OP funcs */
24
32
  template<typename T>
25
- static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
26
- for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
27
- dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
28
- }
33
+ static __dpct_inline__ T op_sgn(T x) {
34
+ return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
29
35
  }
30
36
 
31
37
  template<typename T>
32
- static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
33
- for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
34
- dst[i] = sycl::fabs(x[i]);
35
- }
38
+ static __dpct_inline__ T op_abs(T x) {
39
+ return sycl::fabs(x);
36
40
  }
37
41
 
38
42
  template<typename T>
39
- static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
40
- for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
41
- dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
42
- }
43
+ static __dpct_inline__ T op_elu(T x) {
44
+ return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
43
45
  }
44
46
 
45
47
  template<typename T>
46
- static void gelu(const T * x, T * dst, const int k,
47
- const sycl::nd_item<3> &item_ct1) {
48
+ static __dpct_inline__ T op_gelu(T x) {
48
49
  const T GELU_COEF_A = static_cast<T>(0.044715f);
49
50
  const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
50
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
51
- item_ct1.get_local_id(2);
51
+ return static_cast<T>(0.5f) * x *
52
+ (static_cast<T>(1.0f) +
53
+ sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
54
+ }
52
55
 
53
- if (i >= k) {
54
- return;
55
- }
56
+ template<typename T>
57
+ static __dpct_inline__ T op_silu(T x) {
58
+ return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
59
+ }
56
60
 
57
- float xi = x[i];
58
- dst[i] = static_cast<T>(0.5f) * xi *
59
- (static_cast<T>(1.0f) +
60
- sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast<T>(1.0f) + GELU_COEF_A * xi * xi)));
61
+ template<typename T>
62
+ static __dpct_inline__ T op_gelu_quick(T x) {
63
+ const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
64
+ return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
61
65
  }
62
66
 
63
67
  template<typename T>
64
- static void silu(const T * x, T * dst, const int k,
65
- const sycl::nd_item<3> &item_ct1) {
66
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
67
- item_ct1.get_local_id(2);
68
+ static __dpct_inline__ T op_gelu_erf(T x) {
69
+ const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
70
+ return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
71
+ }
68
72
 
69
- if (i >= k) {
70
- return;
71
- }
72
- dst[i] = x[i] / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
73
+ template<typename T>
74
+ static __dpct_inline__ T op_tanh(T x) {
75
+ return sycl::tanh(x);
73
76
  }
74
77
 
75
78
  template<typename T>
76
- static void gelu_quick(const T *x, T *dst, int k,
77
- const sycl::nd_item<3> &item_ct1) {
78
- const float GELU_QUICK_COEF = -1.702f;
79
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
80
- item_ct1.get_local_id(2);
81
- if (i >= k) {
82
- return;
83
- }
84
- dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
79
+ static __dpct_inline__ T op_relu(T x) {
80
+ return sycl::fmax(x, static_cast<T>(0));
85
81
  }
86
82
 
87
83
  template<typename T>
88
- static void tanh(const T *x, T *dst, int k,
89
- const sycl::nd_item<3> &item_ct1) {
90
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
91
- item_ct1.get_local_id(2);
92
- if (i >= k) {
93
- return;
94
- }
95
- dst[i] = sycl::tanh((x[i]));
84
+ static __dpct_inline__ T op_sigmoid(T x) {
85
+ return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
96
86
  }
97
87
 
98
88
  template<typename T>
99
- static void relu(const T * x, T * dst, const int k,
100
- const sycl::nd_item<3> &item_ct1) {
101
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
102
- item_ct1.get_local_id(2);
89
+ static __dpct_inline__ T op_sqrt(T x) {
90
+ return sycl::sqrt(x);
91
+ }
103
92
 
104
- if (i >= k) {
105
- return;
106
- }
107
- dst[i] = sycl::fmax((x[i]), static_cast<T>(0));
93
+ template<typename T>
94
+ static __dpct_inline__ T op_sin(T x) {
95
+ return sycl::sin(x);
108
96
  }
109
97
 
110
98
  template<typename T>
111
- static void sigmoid(const T * x, T * dst, const int k,
112
- const sycl::nd_item<3> &item_ct1) {
113
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
114
- item_ct1.get_local_id(2);
99
+ static __dpct_inline__ T op_cos(T x) {
100
+ return sycl::cos(x);
101
+ }
115
102
 
116
- if (i >= k) {
117
- return;
103
+ template<typename T>
104
+ static __dpct_inline__ T op_hardsigmoid(T x) {
105
+ return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
106
+ }
107
+
108
+ template<typename T>
109
+ static __dpct_inline__ T op_hardswish(T x) {
110
+ return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
111
+ }
112
+
113
+ template<typename T>
114
+ static __dpct_inline__ T op_exp(T x) {
115
+ return sycl::exp(x);
116
+ }
117
+
118
+ template<typename T>
119
+ static __dpct_inline__ T op_log(T x) {
120
+ if (x <= static_cast<T>(0)) {
121
+ return neg_infinity<T>();
118
122
  }
119
- dst[i] = 1.0f / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
123
+ return sycl::log(x);
120
124
  }
121
125
 
122
126
  template<typename T>
123
- static void sqrt(const T * x, T * dst, const int k,
124
- const sycl::nd_item<3> &item_ct1) {
125
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
126
- item_ct1.get_local_id(2);
127
+ static __dpct_inline__ T op_neg(T x) {
128
+ return -x;
129
+ }
127
130
 
128
- if (i >= k) {
129
- return;
131
+ template<typename T>
132
+ static __dpct_inline__ T op_step(T x) {
133
+ return (x > static_cast<T>(0.0f)) ? static_cast<T>(1.0f) : static_cast<T>(0.0f);
134
+ }
135
+
136
+ template<typename T>
137
+ static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
138
+ T neg_slope_T = static_cast<T>(negative_slope);
139
+ return sycl::fmax(x, static_cast<T>(0)) +
140
+ sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
141
+ }
142
+
143
+ template<typename T>
144
+ static __dpct_inline__ T op_sqr(T x) {
145
+ return x * x;
146
+ }
147
+
148
+ template<typename T>
149
+ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
150
+ return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
151
+ }
152
+
153
+ template<typename T>
154
+ static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
155
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
156
+ dst[i] = op_sgn(x[i]);
130
157
  }
131
- dst[i] = sycl::sqrt(x[i]);
132
158
  }
133
159
 
134
160
  template<typename T>
135
- static void sin(const T * x, T * dst, const int k,
136
- const sycl::nd_item<3> &item_ct1) {
137
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
138
- item_ct1.get_local_id(2);
161
+ static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
162
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
163
+ dst[i] = op_abs(x[i]);
164
+ }
165
+ }
139
166
 
140
- if (i >= k) {
141
- return;
167
+ template<typename T>
168
+ static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
169
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
170
+ dst[i] = op_elu(x[i]);
142
171
  }
143
- dst[i] = sycl::sin(x[i]);
144
172
  }
145
173
 
146
174
  template<typename T>
147
- static void cos(const T * x, T * dst, const int k,
148
- const sycl::nd_item<3> &item_ct1) {
149
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
150
- item_ct1.get_local_id(2);
175
+ static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
176
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
177
+ dst[i] = op_gelu(x[i]);
178
+ }
179
+ }
151
180
 
152
- if (i >= k) {
153
- return;
181
+ template<typename T>
182
+ static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
183
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
184
+ dst[i] = op_silu(x[i]);
154
185
  }
155
- dst[i] = sycl::cos(x[i]);
156
186
  }
157
187
 
158
188
  template<typename T>
159
- static void hardsigmoid(const T * x, T * dst, const int k,
160
- const sycl::nd_item<3> &item_ct1) {
161
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
162
- item_ct1.get_local_id(2);
189
+ static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
190
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
191
+ dst[i] = op_gelu_quick(x[i]);
192
+ }
193
+ }
163
194
 
164
- if (i >= k) {
165
- return;
195
+ template<typename T>
196
+ static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
197
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
198
+ dst[i] = op_gelu_erf(x[i]);
199
+ }
200
+ }
201
+
202
+ template<typename T>
203
+ static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
204
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
205
+ dst[i] = op_tanh(x[i]);
166
206
  }
167
- dst[i] = sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
168
207
  }
169
208
 
170
209
  template<typename T>
171
- static void hardswish(const T * x, T * dst, const int k,
172
- const sycl::nd_item<3> &item_ct1) {
173
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
174
- item_ct1.get_local_id(2);
210
+ static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
211
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
212
+ dst[i] = op_relu(x[i]);
213
+ }
214
+ }
175
215
 
176
- if (i >= k) {
177
- return;
216
+ template<typename T>
217
+ static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
218
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
219
+ dst[i] = op_sigmoid(x[i]);
178
220
  }
179
- dst[i] = x[i] * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
180
221
  }
181
222
 
182
223
  template<typename T>
183
- static void exp(const T * x, T * dst, const int k,
184
- const sycl::nd_item<3> &item_ct1) {
185
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
186
- item_ct1.get_local_id(2);
224
+ static void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
225
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
226
+ dst[i] = op_sqrt(x[i]);
227
+ }
228
+ }
187
229
 
188
- if (i >= k) {
189
- return;
230
+ template<typename T>
231
+ static void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
232
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
233
+ dst[i] = op_sin(x[i]);
190
234
  }
191
- dst[i] = sycl::exp(x[i]);
192
235
  }
193
236
 
194
237
  template<typename T>
195
- static void log(const T * x, T * dst, const int k,
196
- const sycl::nd_item<3> &item_ct1) {
197
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
198
- item_ct1.get_local_id(2);
238
+ static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
239
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
240
+ dst[i] = op_cos(x[i]);
241
+ }
242
+ }
199
243
 
200
- if (i >= k) {
201
- return;
244
+ template<typename T>
245
+ static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
246
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
247
+ dst[i] = op_hardsigmoid(x[i]);
202
248
  }
203
- T xi = x[i];
204
- if (xi <= 0) {
205
- dst[i] = neg_infinity<T>();
206
- } else {
207
- dst[i] = sycl::log(xi);
249
+ }
250
+
251
+ template<typename T>
252
+ static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
253
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
254
+ dst[i] = op_hardswish(x[i]);
208
255
  }
209
256
  }
210
257
 
211
258
  template<typename T>
212
- static void neg(const T * x, T * dst, const int k,
213
- const sycl::nd_item<3> &item_ct1) {
214
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
215
- item_ct1.get_local_id(2);
259
+ static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
260
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
261
+ dst[i] = op_exp(x[i]);
262
+ }
263
+ }
216
264
 
217
- if (i >= k) {
218
- return;
265
+ template<typename T>
266
+ static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
267
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
268
+ dst[i] = op_log(x[i]);
219
269
  }
220
- dst[i] = -x[i];
221
270
  }
222
271
 
223
272
  template<typename T>
224
- static void step(const T * x, T * dst, const int k,
225
- const sycl::nd_item<3> &item_ct1) {
226
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
227
- item_ct1.get_local_id(2);
273
+ static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
274
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
275
+ dst[i] = op_neg(x[i]);
276
+ }
277
+ }
228
278
 
229
- if (i >= k) {
230
- return;
279
+ template<typename T>
280
+ static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
281
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
282
+ dst[i] = op_step(x[i]);
231
283
  }
232
- dst[i] = x[i] > static_cast<T>(0.0f);
233
284
  }
234
285
 
235
286
  template<typename T>
236
- static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope,
237
- const sycl::nd_item<3> &item_ct1) {
238
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
239
- item_ct1.get_local_id(2);
240
- if (i >= k) {
241
- return;
287
+ static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
288
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
289
+ dst[i] = op_leaky_relu(x[i], negative_slope);
242
290
  }
243
- dst[i] = sycl::fmax((x[i]), static_cast<T>(0)) +
244
- sycl::fmin((x[i]), static_cast<T>(0.0f)) * negative_slope;
245
291
  }
246
292
 
247
293
  template<typename T>
248
- static void sqr(const T * x, T * dst, const int k,
249
- const sycl::nd_item<3> &item_ct1) {
250
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
251
- item_ct1.get_local_id(2);
294
+ static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
295
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
296
+ dst[i] = op_sqr(x[i]);
297
+ }
298
+ }
252
299
 
253
- if (i >= k) {
254
- return;
300
+ template<typename T>
301
+ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
302
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
303
+ dst[i] = op_clamp(x[i], min_val, max_val);
255
304
  }
256
- dst[i] = x[i] * x[i];
257
305
  }
258
306
 
259
307
  template<typename T>
@@ -272,10 +320,10 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
272
320
  int i12 = (index / (ne10 * ne11)) % ne12;
273
321
  int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
274
322
 
275
- int i00 = i10 / sf0;
276
- int i01 = i11 / sf1;
277
- int i02 = i12 / sf2;
278
- int i03 = i13 / sf3;
323
+ int i00 = static_cast<int>(i10 / sf0);
324
+ int i01 = static_cast<int>(i11 / sf1);
325
+ int i02 = static_cast<int>(i12 / sf2);
326
+ int i03 = static_cast<int>(i13 / sf3);
279
327
 
280
328
  dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
281
329
  }
@@ -283,8 +331,7 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
283
331
  template <typename T>
284
332
  static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
285
333
  const sycl::nd_item<3> &item_ct1) {
286
- int nidx = item_ct1.get_local_id(2) +
287
- item_ct1.get_group(2) * item_ct1.get_local_range(2);
334
+ int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
288
335
  if (nidx >= ne0) {
289
336
  return;
290
337
  }
@@ -301,285 +348,72 @@ static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne
301
348
  }
302
349
  }
303
350
 
304
-
305
351
  template<typename T>
306
352
  static void clamp(const T * x, T * dst, const float min, const float max, const int k,
307
- const sycl::nd_item<3> &item_ct1) {
308
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
309
- item_ct1.get_local_id(2);
310
-
311
- if (i >= k) {
312
- return;
353
+ const sycl::nd_item<1> &item_ct1) {
354
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
355
+ dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
313
356
  }
314
-
315
- dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
316
- }
317
-
318
- static void acc_f32_sycl(const float *x, const float *y, float *dst,
319
- const int n_elements, const int ne10, const int ne11,
320
- const int ne12, const int nb1, const int nb2,
321
- const int offset, queue_ptr stream) {
322
- int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
323
- stream->parallel_for(
324
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
325
- sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
326
- sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
327
- [=](sycl::nd_item<3> item_ct1) {
328
- acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
329
- item_ct1);
330
- });
331
- }
332
-
333
- template<typename T>
334
- static void gelu_sycl(const T *x, T *dst, const int k,
335
- queue_ptr stream) {
336
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
337
- stream->parallel_for(
338
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
339
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
340
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
341
- [=](sycl::nd_item<3> item_ct1) {
342
- gelu(x, dst, k, item_ct1);
343
- });
344
- }
345
-
346
- template<typename T>
347
- static void silu_sycl(const T *x, T *dst, const int k,
348
- queue_ptr stream) {
349
- const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
350
- stream->parallel_for(
351
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
352
- sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
353
- sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
354
- [=](sycl::nd_item<3> item_ct1) {
355
- silu(x, dst, k, item_ct1);
356
- });
357
- }
358
-
359
- template<typename T>
360
- static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
361
- // hard code for now
362
- const int num_blocks = ceil_div(k, 256);
363
- stream->parallel_for(
364
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
365
- sgn(x, dst, k, item_ct1);
366
- });
367
- }
368
-
369
- template<typename T>
370
- static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
371
- // hard code for now
372
- const int num_blocks = ceil_div(k, 256);
373
- stream->parallel_for(
374
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
375
- abs_op(x, dst, k, item_ct1);
376
- });
377
- }
378
-
379
-
380
- template<typename T>
381
- static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
382
- // hard code for now
383
- const int num_blocks = ceil_div(k, 256);
384
- stream->parallel_for(
385
- sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
386
- elu_op(x, dst, k, item_ct1);
387
- });
388
- }
389
-
390
- template<typename T>
391
- static void gelu_quick_sycl(const T *x, T *dst, const int k,
392
- queue_ptr stream) {
393
- const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
394
- stream->parallel_for(
395
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
396
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
397
- sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
398
- [=](sycl::nd_item<3> item_ct1) {
399
- gelu_quick(x, dst, k, item_ct1);
400
- });
401
- }
402
-
403
- template<typename T>
404
- static void tanh_sycl(const T *x, T *dst, const int k,
405
- queue_ptr stream) {
406
- const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
407
- stream->parallel_for(
408
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
409
- sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
410
- sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
411
- [=](sycl::nd_item<3> item_ct1) {
412
- tanh(x, dst, k, item_ct1);
413
- });
414
- }
415
-
416
- template<typename T>
417
- static void relu_sycl(const T *x, T *dst, const int k,
418
- queue_ptr stream) {
419
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
420
- stream->parallel_for(
421
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
422
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
423
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
424
- [=](sycl::nd_item<3> item_ct1) {
425
- relu(x, dst, k, item_ct1);
426
- });
427
- }
428
-
429
- template<typename T>
430
- static void hardsigmoid_sycl(const T *x, T *dst, const int k,
431
- queue_ptr stream) {
432
- const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
433
- stream->parallel_for(
434
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
435
- sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
436
- sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
437
- [=](sycl::nd_item<3> item_ct1) {
438
- hardsigmoid(x, dst, k, item_ct1);
439
- });
440
- }
441
-
442
- template<typename T>
443
- static void hardswish_sycl(const T *x, T *dst, const int k,
444
- queue_ptr stream) {
445
- const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
446
- stream->parallel_for(
447
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
448
- sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
449
- sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
450
- [=](sycl::nd_item<3> item_ct1) {
451
- hardswish(x, dst, k, item_ct1);
452
- });
453
- }
454
-
455
- template<typename T>
456
- static void exp_sycl(const T *x, T *dst, const int k,
457
- queue_ptr stream) {
458
- const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
459
- stream->parallel_for(
460
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
461
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
462
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
463
- [=](sycl::nd_item<3> item_ct1) {
464
- exp(x, dst, k, item_ct1);
465
- });
466
- }
467
-
468
- template<typename T>
469
- static void log_sycl(const T *x, T *dst, const int k,
470
- queue_ptr stream) {
471
- const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
472
- stream->parallel_for(
473
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
474
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
475
- sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
476
- [=](sycl::nd_item<3> item_ct1) {
477
- log(x, dst, k, item_ct1);
478
- });
479
- }
480
-
481
- template<typename T>
482
- static void neg_sycl(const T *x, T *dst, const int k,
483
- queue_ptr stream) {
484
- const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
485
- stream->parallel_for(
486
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
487
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
488
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
489
- [=](sycl::nd_item<3> item_ct1) {
490
- neg(x, dst, k, item_ct1);
491
- });
492
- }
493
-
494
- template<typename T>
495
- static void step_sycl(const T *x, T *dst, const int k,
496
- queue_ptr stream) {
497
- const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
498
- stream->parallel_for(
499
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
500
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
501
- sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
502
- [=](sycl::nd_item<3> item_ct1) {
503
- step(x, dst, k, item_ct1);
504
- });
505
357
  }
506
358
 
507
359
  template<typename T>
508
- static void sigmoid_sycl(const T *x, T *dst, const int k,
509
- queue_ptr stream) {
510
- const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
511
- stream->parallel_for(
512
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
513
- sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
514
- sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
515
- [=](sycl::nd_item<3> item_ct1) {
516
- sigmoid(x, dst, k, item_ct1);
517
- });
360
+ static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
361
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
362
+ const int64_t j0 = (i / n) * o0 + (i % n);
363
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
364
+ dst[i] = op_gelu(x[j0]) * g[j1];
365
+ }
518
366
  }
519
367
 
520
368
  template<typename T>
521
- static void sqrt_sycl(const T *x, T *dst, const int k,
522
- queue_ptr stream) {
523
- const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
524
- stream->parallel_for(
525
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
526
- sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
527
- sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
528
- [=](sycl::nd_item<3> item_ct1) {
529
- sqrt(x, dst, k, item_ct1);
530
- });
369
+ static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
370
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
371
+ const int64_t j0 = (i / n) * o0 + (i % n);
372
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
373
+ dst[i] = op_relu(x[j0]) * g[j1];
374
+ }
531
375
  }
532
376
 
533
377
  template<typename T>
534
- static void sin_sycl(const T *x, T *dst, const int k,
535
- queue_ptr stream) {
536
- const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
537
- stream->parallel_for(
538
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
539
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
540
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
541
- [=](sycl::nd_item<3> item_ct1) {
542
- sin(x, dst, k, item_ct1);
543
- });
378
+ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
379
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
380
+ const int64_t j0 = (i / n) * o0 + (i % n);
381
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
382
+ dst[i] = op_silu(x[j0]) * g[j1];
383
+ }
544
384
  }
545
385
 
546
386
  template<typename T>
547
- static void cos_sycl(const T *x, T *dst, const int k,
548
- queue_ptr stream) {
549
- const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
550
- stream->parallel_for(
551
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
552
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
553
- sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
554
- [=](sycl::nd_item<3> item_ct1) {
555
- cos(x, dst, k, item_ct1);
556
- });
387
+ static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
388
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
389
+ const int64_t j0 = (i / n) * o0 + (i % n);
390
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
391
+ dst[i] = op_gelu_erf(x[j0]) * g[j1];
392
+ }
557
393
  }
558
394
 
559
395
  template<typename T>
560
- static void leaky_relu_sycl(const T *x, T *dst, const int k,
561
- const float negative_slope,
562
- queue_ptr stream) {
563
- const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
564
- stream->parallel_for(
565
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
566
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
567
- sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
568
- [=](sycl::nd_item<3> item_ct1) {
569
- leaky_relu(x, dst, k, negative_slope, item_ct1);
570
- });
396
+ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
397
+ SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
398
+ const int64_t j0 = (i / n) * o0 + (i % n);
399
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
400
+ dst[i] = op_gelu_quick(x[j0]) * g[j1];
401
+ }
571
402
  }
572
403
 
573
- template<typename T>
574
- static void sqr_sycl(const T *x, T *dst, const int k,
575
- queue_ptr stream) {
576
- const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
404
+ namespace ggml_sycl_detail {
405
+ static void acc_f32_sycl(const float *x, const float *y, float *dst,
406
+ const int n_elements, const int ne10, const int ne11,
407
+ const int ne12, const int nb1, const int nb2,
408
+ const int offset, queue_ptr stream) {
409
+ int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
577
410
  stream->parallel_for(
578
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
579
- sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
580
- sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
581
- [=](sycl::nd_item<3> item_ct1) {
582
- sqr(x, dst, k, item_ct1);
411
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) *
412
+ sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
413
+ sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
414
+ [=](sycl::nd_item<1> item_ct1) {
415
+ acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
416
+ item_ct1);
583
417
  });
584
418
  }
585
419
 
@@ -589,11 +423,10 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
589
423
  const int ne12, const int ne13, const float sf0, const float sf1,
590
424
  const float sf2, const float sf3, queue_ptr stream) {
591
425
  int dst_size = ne10 * ne11 * ne12 * ne13;
592
- int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
426
+ int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
593
427
  sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
594
428
  stream->parallel_for(
595
- sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
596
- [=](sycl::nd_item<1> item_ct1) {
429
+ sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
597
430
  upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
598
431
  });
599
432
  }
@@ -602,35 +435,19 @@ template<typename T>
602
435
  static void pad_sycl(const T *x, T *dst, const int ne00,
603
436
  const int ne01, const int ne02, const int ne0,
604
437
  const int ne1, const int ne2, queue_ptr stream) {
605
- int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
438
+ int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
606
439
  sycl::range<3> gridDim(ne2, ne1, num_blocks);
607
440
  stream->parallel_for(
608
- sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
609
- sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
610
- [=](sycl::nd_item<3> item_ct1) {
611
- pad(x, dst, ne0, ne00, ne01, ne02, item_ct1);
612
- });
613
- }
614
-
615
- template<typename T>
616
- static void clamp_sycl(const T *x, T *dst, const float min,
617
- const float max, const int k,
618
- queue_ptr stream) {
619
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
620
- stream->parallel_for(
621
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
622
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
623
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
624
- [=](sycl::nd_item<3> item_ct1) {
625
- clamp(x, dst, min, max, k, item_ct1);
626
- });
441
+ sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
442
+ sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
443
+ [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
627
444
  }
628
445
 
629
- inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
446
+ template<typename KernelInvoker, typename... Args>
447
+ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
630
448
  #if defined (GGML_SYCL_F16)
631
449
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
632
450
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
633
-
634
451
  #else
635
452
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
636
453
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -643,14 +460,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
643
460
  case GGML_TYPE_F16:
644
461
  {
645
462
  auto data_pts = cast_data<sycl::half>(dst);
646
- sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
463
+ kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
647
464
  break;
648
465
  }
649
466
  #endif
650
467
  case GGML_TYPE_F32:
651
468
  {
652
469
  auto data_pts = cast_data<float>(dst);
653
- sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
470
+ kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
654
471
  break;
655
472
  }
656
473
  default:
@@ -658,11 +475,11 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
658
475
  }
659
476
  }
660
477
 
661
- inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
478
+ template<typename KernelInvoker, typename... Args>
479
+ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
662
480
  #if defined (GGML_SYCL_F16)
663
481
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
664
482
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
665
-
666
483
  #else
667
484
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
668
485
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -670,19 +487,66 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
670
487
  GGML_ASSERT(dst->src[0]->type == dst->type);
671
488
  dpct::queue_ptr main_stream = ctx.stream();
672
489
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
490
+ const ggml_tensor * src0 = dst->src[0];
491
+ const ggml_tensor * src1 = dst->src[1];
492
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;;
493
+ GGML_ASSERT(dst->ne[0] == nc);
494
+ GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
495
+ GGML_ASSERT(ggml_is_contiguous(dst));
496
+ const int32_t swapped = ((const int32_t *) dst->op_params)[1];
497
+ void * src0_d = src0->data;
498
+ void * src1_d = src1 ? src1->data : src0->data;
499
+ const int64_t src0_o = src0->nb[1];
500
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
501
+ void * dst_d = dst->data;
502
+ if (src1) {
503
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
504
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
505
+ GGML_ASSERT(src1->ne[0] == nc);
506
+ GGML_ASSERT(src0->type == src1->type);
507
+ }
673
508
  switch (dst->type) {
674
509
  #if defined (GGML_SYCL_F16)
675
510
  case GGML_TYPE_F16:
676
511
  {
677
- auto data_pts = cast_data<sycl::half>(dst);
678
- abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
512
+ sycl::half * src0_p = (sycl::half *) src0_d;
513
+ sycl::half * src1_p = (sycl::half *) src1_d;
514
+
515
+ if (!src1) {
516
+ src0_p += swapped ? nc : 0;
517
+ src1_p += swapped ? 0 : nc;
518
+ }
519
+ kernel_invoker(src0_p,
520
+ src1_p,
521
+ (sycl::half *) dst_d,
522
+ ggml_nelements(dst),
523
+ nc,
524
+ src0_o / sizeof(sycl::half),
525
+ src1_o / sizeof(sycl::half),
526
+ main_stream,
527
+ std::forward<Args>(args)...);
679
528
  break;
680
529
  }
681
530
  #endif
682
531
  case GGML_TYPE_F32:
683
532
  {
684
- auto data_pts = cast_data<float>(dst);
685
- abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
533
+ float * src0_p = (float *) src0_d;
534
+ float * src1_p = (float *) src1_d;
535
+
536
+ if (!src1) {
537
+ src0_p += swapped ? nc : 0;
538
+ src1_p += swapped ? 0 : nc;
539
+ }
540
+
541
+ kernel_invoker(src0_p,
542
+ src1_p,
543
+ (float *) dst_d,
544
+ ggml_nelements(dst),
545
+ nc,
546
+ src0_o / sizeof(float),
547
+ src1_o / sizeof(float),
548
+ main_stream,
549
+ std::forward<Args>(args)...);
686
550
  break;
687
551
  }
688
552
  default:
@@ -690,32 +554,41 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
690
554
  }
691
555
  }
692
556
 
693
-
694
- inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
557
+ template<typename KernelInvoker, typename... Args>
558
+ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
695
559
  #if defined (GGML_SYCL_F16)
696
560
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
697
561
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
698
-
699
562
  #else
700
563
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
701
564
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
702
565
  #endif
703
566
  GGML_ASSERT(dst->src[0]->type == dst->type);
567
+
704
568
  dpct::queue_ptr main_stream = ctx.stream();
705
569
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
570
+
571
+ const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
572
+ const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
573
+ const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
574
+ const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
706
575
  switch (dst->type) {
707
576
  #if defined (GGML_SYCL_F16)
708
577
  case GGML_TYPE_F16:
709
578
  {
710
579
  auto data_pts = cast_data<sycl::half>(dst);
711
- elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
580
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
581
+ (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
582
+ main_stream, std::forward<Args>(args)...);
712
583
  break;
713
584
  }
714
585
  #endif
715
586
  case GGML_TYPE_F32:
716
587
  {
717
588
  auto data_pts = cast_data<float>(dst);
718
- elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
589
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
590
+ (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
591
+ main_stream, std::forward<Args>(args)...);
719
592
  break;
720
593
  }
721
594
  default:
@@ -723,7 +596,8 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
723
596
  }
724
597
  }
725
598
 
726
- inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
599
+ template<typename KernelInvoker, typename... Args>
600
+ static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
727
601
  #if defined (GGML_SYCL_F16)
728
602
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
729
603
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -732,6 +606,7 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
732
606
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
733
607
  #endif
734
608
  GGML_ASSERT(dst->src[0]->type == dst->type);
609
+ GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
735
610
  dpct::queue_ptr main_stream = ctx.stream();
736
611
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
737
612
  switch (dst->type) {
@@ -739,14 +614,16 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
739
614
  case GGML_TYPE_F16:
740
615
  {
741
616
  auto data_pts = cast_data<sycl::half>(dst);
742
- silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
617
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
618
+ (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
743
619
  break;
744
620
  }
745
621
  #endif
746
622
  case GGML_TYPE_F32:
747
623
  {
748
624
  auto data_pts = cast_data<float>(dst);
749
- silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
625
+ kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
626
+ (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
750
627
  break;
751
628
  }
752
629
  default:
@@ -754,623 +631,320 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
754
631
  }
755
632
  }
756
633
 
757
- inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
758
- #if defined (GGML_SYCL_F16)
759
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
760
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
761
- #else
762
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
763
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
764
- #endif
765
- GGML_ASSERT(dst->src[0]->type == dst->type);
766
- dpct::queue_ptr main_stream = ctx.stream();
767
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
768
- switch (dst->type) {
769
- #if defined (GGML_SYCL_F16)
770
- case GGML_TYPE_F16:
771
- {
772
- auto data_pts = cast_data<sycl::half>(dst);
773
- gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
774
- break;
775
- }
776
- #endif
777
- case GGML_TYPE_F32:
778
- {
779
- auto data_pts = cast_data<float>(dst);
780
- gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
781
- break;
782
- }
783
- default:
784
- GGML_ABORT("GGML tensor type not supported!\n");
785
- }
634
+ } // namespace ggml_sycl_detail
635
+
636
+
637
+
638
+ static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
639
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
640
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
641
+ const int num_blocks = ceil_div(k_elements, 256);
642
+ stream->parallel_for(
643
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
644
+ sycl::range<1>(256)),
645
+ [=](sycl::nd_item<1> item_ct1) {
646
+ unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
647
+ });
648
+ });
786
649
  }
787
650
 
788
- inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
789
- #if defined (GGML_SYCL_F16)
790
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
791
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
792
- #else
793
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
794
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
795
- #endif
796
- GGML_ASSERT(dst->src[0]->type == dst->type);
797
- dpct::queue_ptr main_stream = ctx.stream();
798
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
799
- switch (dst->type) {
800
- #if defined (GGML_SYCL_F16)
801
- case GGML_TYPE_F16:
802
- {
803
- auto data_pts = cast_data<sycl::half>(dst);
804
- gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
805
- break;
806
- }
807
- #endif
808
- case GGML_TYPE_F32:
809
- {
810
- auto data_pts = cast_data<float>(dst);
811
- gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
812
- break;
813
- }
814
- default:
815
- GGML_ABORT("GGML tensor type not supported!\n");
816
- }
817
- }
818
-
819
- inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
820
- #if defined (GGML_SYCL_F16)
821
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
822
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
823
- #else
824
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
825
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
826
- #endif
827
- GGML_ASSERT(dst->src[0]->type == dst->type);
828
- dpct::queue_ptr main_stream = ctx.stream();
829
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
830
- switch (dst->type) {
831
- #if defined (GGML_SYCL_F16)
832
- case GGML_TYPE_F16:
833
- {
834
- auto data_pts = cast_data<sycl::half>(dst);
835
- tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
836
- break;
837
- }
838
- #endif
839
- case GGML_TYPE_F32:
840
- {
841
- auto data_pts = cast_data<float>(dst);
842
- tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
843
- break;
844
- }
845
- default:
846
- GGML_ABORT("GGML tensor type not supported!\n");
847
- }
848
- }
849
-
850
- inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
851
- #if defined (GGML_SYCL_F16)
852
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
853
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
854
- #else
855
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
856
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
857
- #endif
858
- GGML_ASSERT(dst->src[0]->type == dst->type);
859
- dpct::queue_ptr main_stream = ctx.stream();
860
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
861
-
862
- switch (dst->type) {
863
- #if defined (GGML_SYCL_F16)
864
- case GGML_TYPE_F16:
865
- {
866
- auto data_pts = cast_data<sycl::half>(dst);
867
- relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
868
- break;
869
- }
870
- #endif
871
- case GGML_TYPE_F32:
872
- {
873
- auto data_pts = cast_data<float>(dst);
874
- relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
875
- break;
876
- }
877
- default:
878
- GGML_ABORT("GGML tensor type not supported!\n");
879
- }
880
- }
881
-
882
- inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
883
- #if defined (GGML_SYCL_F16)
884
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
885
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
886
- #else
887
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
888
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
889
- #endif
890
- GGML_ASSERT(dst->src[0]->type == dst->type);
891
-
892
- dpct::queue_ptr main_stream = ctx.stream();
893
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
894
-
895
- switch (dst->type) {
896
- #if defined (GGML_SYCL_F16)
897
- case GGML_TYPE_F16:
898
- {
899
- auto data_pts = cast_data<sycl::half>(dst);
900
- hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
901
- break;
902
- }
903
- #endif
904
- case GGML_TYPE_F32:
905
- {
906
- auto data_pts = cast_data<float>(dst);
907
- hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
908
- break;
909
- }
910
- default:
911
- GGML_ABORT("GGML tensor type not supported!\n");
912
- }
913
- }
914
-
915
- inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
916
- #if defined (GGML_SYCL_F16)
917
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
918
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
919
- #else
920
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
921
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
922
- #endif
923
- GGML_ASSERT(dst->src[0]->type == dst->type);
924
- dpct::queue_ptr main_stream = ctx.stream();
925
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
926
- switch (dst->type) {
927
- #if defined (GGML_SYCL_F16)
928
- case GGML_TYPE_F16:
929
- {
930
- auto data_pts = cast_data<sycl::half>(dst);
931
- hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
932
- break;
933
- }
934
- #endif
935
- case GGML_TYPE_F32:
936
- {
937
- auto data_pts = cast_data<float>(dst);
938
- hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
939
- break;
940
- }
941
- default:
942
- GGML_ABORT("GGML tensor type not supported!\n");
943
- }
944
- }
945
-
946
- inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
947
- #if defined (GGML_SYCL_F16)
948
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
949
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
950
- #else
951
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
952
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
953
- #endif
954
- GGML_ASSERT(dst->src[0]->type == dst->type);
955
- dpct::queue_ptr main_stream = ctx.stream();
956
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
957
- switch (dst->type) {
958
- #if defined (GGML_SYCL_F16)
959
- case GGML_TYPE_F16:
960
- {
961
- auto data_pts = cast_data<sycl::half>(dst);
962
- exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
963
- break;
964
- }
965
- #endif
966
- case GGML_TYPE_F32:
967
- {
968
- auto data_pts = cast_data<float>(dst);
969
- exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
970
- break;
971
- }
972
- default:
973
- GGML_ABORT("GGML tensor type not supported!\n");
974
- }
975
- }
976
-
977
- inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
978
- #if defined (GGML_SYCL_F16)
979
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
980
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
981
- #else
982
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
983
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
984
- #endif
985
- GGML_ASSERT(dst->src[0]->type == dst->type);
986
- dpct::queue_ptr main_stream = ctx.stream();
987
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
988
- switch (dst->type) {
989
- #if defined (GGML_SYCL_F16)
990
- case GGML_TYPE_F16:
991
- {
992
- auto data_pts = cast_data<sycl::half>(dst);
993
- log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
994
- break;
995
- }
996
- #endif
997
- case GGML_TYPE_F32:
998
- {
999
- auto data_pts = cast_data<float>(dst);
1000
- log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1001
- break;
1002
- }
1003
- default:
1004
- GGML_ABORT("GGML tensor type not supported!\n");
1005
- }
1006
- }
1007
-
1008
- inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1009
- #if defined (GGML_SYCL_F16)
1010
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1011
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1012
- #else
1013
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1014
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1015
- #endif
1016
- GGML_ASSERT(dst->src[0]->type == dst->type);
1017
- dpct::queue_ptr main_stream = ctx.stream();
1018
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1019
- switch (dst->type) {
1020
- #if defined (GGML_SYCL_F16)
1021
- case GGML_TYPE_F16:
1022
- {
1023
- auto data_pts = cast_data<sycl::half>(dst);
1024
- sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1025
- break;
1026
- }
1027
- #endif
1028
- case GGML_TYPE_F32:
1029
- {
1030
- auto data_pts = cast_data<float>(dst);
1031
- sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1032
- break;
1033
- }
1034
- default:
1035
- GGML_ABORT("GGML tensor type not supported!\n");
1036
- }
651
+ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
652
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
653
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
654
+ const int num_blocks = ceil_div(k_elements, 256);
655
+ stream->parallel_for(
656
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
657
+ sycl::range<1>(256)),
658
+ [=](sycl::nd_item<1> item_ct1) {
659
+ unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
660
+ });
661
+ });
1037
662
  }
1038
663
 
1039
- inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1040
- #if defined (GGML_SYCL_F16)
1041
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1042
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1043
- #else
1044
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1045
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1046
- #endif
1047
- GGML_ASSERT(dst->src[0]->type == dst->type);
1048
-
1049
- dpct::queue_ptr main_stream = ctx.stream();
1050
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1051
- switch (dst->type) {
1052
- #if defined (GGML_SYCL_F16)
1053
- case GGML_TYPE_F16:
1054
- {
1055
- auto data_pts = cast_data<sycl::half>(dst);
1056
- sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1057
- break;
1058
- }
1059
- #endif
1060
- case GGML_TYPE_F32:
1061
- {
1062
- auto data_pts = cast_data<float>(dst);
1063
- sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1064
- break;
1065
- }
1066
- default:
1067
- GGML_ABORT("GGML tensor type not supported!\n");
1068
- }
664
+ static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
665
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
666
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
667
+ const int num_blocks = ceil_div(k_elements, 256);
668
+ stream->parallel_for(
669
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
670
+ sycl::range<1>(256)),
671
+ [=](sycl::nd_item<1> item_ct1) {
672
+ unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
673
+ });
674
+ });
1069
675
  }
1070
676
 
1071
- inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1072
- #if defined (GGML_SYCL_F16)
1073
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1074
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1075
- #else
1076
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1077
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1078
- #endif
1079
- GGML_ASSERT(dst->src[0]->type == dst->type);
1080
- dpct::queue_ptr main_stream = ctx.stream();
1081
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1082
- switch (dst->type) {
1083
- #if defined (GGML_SYCL_F16)
1084
- case GGML_TYPE_F16:
1085
- {
1086
- auto data_pts = cast_data<sycl::half>(dst);
1087
- sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1088
- break;
1089
- }
1090
- #endif
1091
- case GGML_TYPE_F32:
1092
- {
1093
- auto data_pts = cast_data<float>(dst);
1094
- sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1095
- break;
1096
- }
1097
- default:
1098
- GGML_ABORT("GGML tensor type not supported!\n");
1099
- }
677
+ static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
678
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
679
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
680
+ const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
681
+ stream->parallel_for(
682
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
683
+ sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
684
+ [=](sycl::nd_item<1> item_ct1) {
685
+ unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
686
+ });
687
+ });
1100
688
  }
1101
689
 
1102
- inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1103
- #if defined (GGML_SYCL_F16)
1104
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1105
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1106
- #else
1107
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1108
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1109
- #endif
1110
- GGML_ASSERT(dst->src[0]->type == dst->type);
1111
- dpct::queue_ptr main_stream = ctx.stream();
1112
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1113
- switch (dst->type) {
1114
- #if defined (GGML_SYCL_F16)
1115
- case GGML_TYPE_F16:
1116
- {
1117
- auto data_pts = cast_data<sycl::half>(dst);
1118
- cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1119
- break;
1120
- }
1121
- #endif
1122
- case GGML_TYPE_F32:
1123
- {
1124
- auto data_pts = cast_data<float>(dst);
1125
- cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1126
- break;
1127
- }
1128
- default:
1129
- GGML_ABORT("GGML tensor type not supported!\n");
1130
- }
690
+ static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
691
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
692
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
693
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
694
+ stream->parallel_for(
695
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
696
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
697
+ [=](sycl::nd_item<1> item_ct1) {
698
+ unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
699
+ });
700
+ });
1131
701
  }
1132
702
 
1133
- inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1134
- #if defined (GGML_SYCL_F16)
1135
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1136
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1137
- #else
1138
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1139
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1140
- #endif
1141
- GGML_ASSERT(dst->src[0]->type == dst->type);
1142
- dpct::queue_ptr main_stream = ctx.stream();
1143
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1144
- switch (dst->type) {
1145
- #if defined (GGML_SYCL_F16)
1146
- case GGML_TYPE_F16:
1147
- {
1148
- auto data_pts = cast_data<sycl::half>(dst);
1149
- step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1150
- break;
1151
- }
1152
- #endif
1153
- case GGML_TYPE_F32:
1154
- {
1155
- auto data_pts = cast_data<float>(dst);
1156
- step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1157
- break;
1158
- }
1159
- default:
1160
- GGML_ABORT("GGML tensor type not supported!\n");
1161
- }
703
+ static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
704
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
705
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
706
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
707
+ stream->parallel_for(
708
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
709
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
710
+ [=](sycl::nd_item<1> item_ct1) {
711
+ unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
712
+ });
713
+ });
1162
714
  }
1163
715
 
1164
- inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1165
- #if defined (GGML_SYCL_F16)
1166
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1167
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1168
- #else
1169
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1170
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1171
- #endif
1172
- GGML_ASSERT(dst->src[0]->type == dst->type);
1173
- dpct::queue_ptr main_stream = ctx.stream();
1174
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1175
- switch (dst->type) {
1176
- #if defined (GGML_SYCL_F16)
1177
- case GGML_TYPE_F16:
1178
- {
1179
- auto data_pts = cast_data<sycl::half>(dst);
1180
- neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1181
- break;
1182
- }
1183
- #endif
1184
- case GGML_TYPE_F32:
1185
- {
1186
- auto data_pts = cast_data<float>(dst);
1187
- neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1188
- break;
1189
- }
1190
- default:
1191
- GGML_ABORT("GGML tensor type not supported!\n");
1192
- }
716
+ static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
717
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
718
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
719
+ const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
720
+ stream->parallel_for(
721
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
722
+ sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
723
+ [=](sycl::nd_item<1> item_ct1) {
724
+ unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
725
+ });
726
+ });
1193
727
  }
1194
728
 
1195
- inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1196
- #if defined (GGML_SYCL_F16)
1197
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1198
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1199
- #else
1200
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1201
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1202
- #endif
729
+ static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
730
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
731
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
732
+ const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
733
+ stream->parallel_for(
734
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
735
+ sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
736
+ [=](sycl::nd_item<1> item_ct1) {
737
+ unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
738
+ });
739
+ });
740
+ }
1203
741
 
1204
- GGML_ASSERT(dst->src[0]->type == dst->type);
1205
- float negative_slope;
1206
- memcpy(&negative_slope, dst->op_params, sizeof(float));
1207
- dpct::queue_ptr main_stream = ctx.stream();
1208
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1209
- switch (dst->type) {
1210
- #if defined (GGML_SYCL_F16)
1211
- case GGML_TYPE_F16:
1212
- {
1213
- auto data_pts = cast_data<sycl::half>(dst);
1214
- leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
1215
- break;
1216
- }
1217
- #endif
1218
- case GGML_TYPE_F32:
1219
- {
1220
- auto data_pts = cast_data<float>(dst);
1221
- leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
1222
- break;
1223
- }
1224
- default:
1225
- GGML_ABORT("GGML tensor type not supported!\n");
1226
- }
742
+ static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
743
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
744
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
745
+ const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
746
+ stream->parallel_for(
747
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
748
+ sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
749
+ [=](sycl::nd_item<1> item_ct1) {
750
+ unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
751
+ });
752
+ });
1227
753
  }
1228
754
 
1229
- inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1230
- #if defined (GGML_SYCL_F16)
1231
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1232
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1233
- #else
1234
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1235
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1236
- #endif
1237
- GGML_ASSERT(dst->src[0]->type == dst->type);
1238
- dpct::queue_ptr main_stream = ctx.stream();
1239
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1240
- switch (dst->type) {
1241
- #if defined (GGML_SYCL_F16)
1242
- case GGML_TYPE_F16:
1243
- {
1244
- auto data_pts = cast_data<sycl::half>(dst);
1245
- sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1246
- break;
1247
- }
1248
- #endif
1249
- case GGML_TYPE_F32:
1250
- {
1251
- auto data_pts = cast_data<float>(dst);
1252
- sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
1253
- break;
1254
- }
1255
- default:
1256
- GGML_ABORT("GGML tensor type not supported!\n");
1257
- }
755
+ static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
756
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
757
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
758
+ const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
759
+ stream->parallel_for(
760
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
761
+ sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
762
+ [=](sycl::nd_item<1> item_ct1) {
763
+ unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
764
+ });
765
+ });
1258
766
  }
1259
767
 
1260
- inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1261
- #if defined (GGML_SYCL_F16)
1262
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1263
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1264
- #else
1265
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1266
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1267
- #endif
1268
- GGML_ASSERT(dst->src[0]->type == dst->type);
768
+ static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
769
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
770
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
771
+ const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
772
+ stream->parallel_for(
773
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
774
+ sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
775
+ [=](sycl::nd_item<1> item_ct1) {
776
+ unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
777
+ });
778
+ });
779
+ }
1269
780
 
1270
- dpct::queue_ptr main_stream = ctx.stream();
1271
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
781
+ static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
782
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
783
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
784
+ const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
785
+ stream->parallel_for(
786
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
787
+ sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
788
+ [=](sycl::nd_item<1> item_ct1) {
789
+ unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
790
+ });
791
+ });
792
+ }
1272
793
 
1273
- const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
1274
- const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
1275
- const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
1276
- const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
1277
- switch (dst->type) {
1278
- #if defined (GGML_SYCL_F16)
1279
- case GGML_TYPE_F16:
1280
- {
1281
- auto data_pts = cast_data<sycl::half>(dst);
1282
- upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
1283
- dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
1284
- main_stream);
1285
- break;
1286
- }
1287
- #endif
1288
- case GGML_TYPE_F32:
1289
- {
1290
- auto data_pts = cast_data<float>(dst);
1291
- upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
1292
- dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
1293
- main_stream);
1294
- break;
1295
- }
1296
- default:
1297
- GGML_ABORT("GGML tensor type not supported!\n");
1298
- }
794
+ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
795
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
796
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
797
+ const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
798
+ stream->parallel_for(
799
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
800
+ sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
801
+ [=](sycl::nd_item<1> item_ct1) {
802
+ unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
803
+ });
804
+ });
1299
805
  }
1300
806
 
1301
- inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1302
- #if defined (GGML_SYCL_F16)
1303
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1304
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1305
- #else
1306
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1307
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1308
- #endif
1309
- GGML_ASSERT(dst->src[0]->type == dst->type);
1310
- GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
1311
- dpct::queue_ptr main_stream = ctx.stream();
1312
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1313
- switch (dst->type) {
1314
- #if defined (GGML_SYCL_F16)
1315
- case GGML_TYPE_F16:
1316
- {
1317
- auto data_pts = cast_data<sycl::half>(dst);
1318
- pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
1319
- dst->ne[1], dst->ne[2], main_stream);
1320
- break;
1321
- }
1322
- #endif
1323
- case GGML_TYPE_F32:
1324
- {
1325
- auto data_pts = cast_data<float>(dst);
1326
- pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
1327
- dst->ne[1], dst->ne[2], main_stream);
1328
- break;
1329
- }
1330
- default:
1331
- GGML_ABORT("GGML tensor type not supported!\n");
1332
- }
807
+ static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
808
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
809
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
810
+ const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
811
+ stream->parallel_for(
812
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
813
+ sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
814
+ [=](sycl::nd_item<1> item_ct1) {
815
+ unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
816
+ });
817
+ });
1333
818
  }
1334
819
 
1335
- inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1336
- #if defined(GGML_SYCL_F16)
1337
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
1338
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1339
- #else
820
+ static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
821
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
822
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
823
+ const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
824
+ stream->parallel_for(
825
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
826
+ sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
827
+ [=](sycl::nd_item<1> item_ct1) {
828
+ unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
829
+ });
830
+ });
831
+ }
1340
832
 
1341
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1342
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
1343
- #endif
1344
- GGML_ASSERT(dst->src[0]->type == dst->type);
1345
- dpct::queue_ptr main_stream = ctx.stream();
1346
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
1347
- float min;
1348
- float max;
1349
- memcpy(&min, dst->op_params, sizeof(float));
1350
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
833
+ static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
834
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
835
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
836
+ const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
837
+ stream->parallel_for(
838
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
839
+ sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
840
+ [=](sycl::nd_item<1> item_ct1) {
841
+ unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
842
+ });
843
+ });
844
+ }
1351
845
 
1352
- switch (dst->type) {
1353
- #if defined(GGML_SYCL_F16)
1354
- case GGML_TYPE_F16:
1355
- {
1356
- auto data_pts = cast_data<sycl::half>(dst);
1357
- clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
1358
- break;
1359
- }
1360
- #endif
1361
- case GGML_TYPE_F32:
1362
- {
1363
- auto data_pts = cast_data<float>(dst);
1364
- clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
1365
- break;
1366
- }
1367
- default:
1368
- GGML_ABORT("GGML tensor type not supported!\n");
1369
- }
846
+ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
847
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
848
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
849
+ const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
850
+ stream->parallel_for(
851
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
852
+ sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
853
+ [=](sycl::nd_item<1> item_ct1) {
854
+ unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
855
+ });
856
+ });
857
+ }
858
+
859
+ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
860
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
861
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
862
+ const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
863
+ stream->parallel_for(
864
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
865
+ sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
866
+ [=](sycl::nd_item<1> item_ct1) {
867
+ unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
868
+ });
869
+ });
870
+ }
871
+
872
+ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
873
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
874
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
875
+ const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
876
+ stream->parallel_for(
877
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
878
+ sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
879
+ [=](sycl::nd_item<1> item_ct1) {
880
+ unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
881
+ });
882
+ });
1370
883
  }
1371
884
 
1372
- inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
885
+ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
886
+ float negative_slope;
887
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
888
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
889
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
890
+ const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
891
+ stream->parallel_for(
892
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
893
+ sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
894
+ [=](sycl::nd_item<1> item_ct1) {
895
+ unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
896
+ });
897
+ }, negative_slope);
898
+ }
899
+
900
+ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
901
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
902
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
903
+ const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
904
+ stream->parallel_for(
905
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
906
+ sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
907
+ [=](sycl::nd_item<1> item_ct1) {
908
+ unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
909
+ });
910
+ });
911
+ }
912
+
913
+ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
914
+ ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
915
+ [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
916
+ int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
917
+ queue_ptr stream) {
918
+ ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
919
+ });
920
+ }
921
+
922
+ static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
923
+ ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
924
+ [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
925
+ queue_ptr stream) {
926
+ ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
927
+ });
928
+ }
1373
929
 
930
+ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
931
+ float min_val;
932
+ float max_val;
933
+ memcpy(&min_val, dst->op_params, sizeof(float));
934
+ memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
935
+ ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
936
+ [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
937
+ const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
938
+ stream->parallel_for(
939
+ sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
940
+ sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
941
+ [=](sycl::nd_item<1> item_ct1) {
942
+ clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
943
+ });
944
+ }, min_val, max_val);
945
+ }
946
+
947
+ static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
1374
948
  GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
1375
949
  GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
1376
950
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -1386,7 +960,62 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
1386
960
  // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
1387
961
  int offset = dst->op_params[3] / 4; // offset in bytes
1388
962
 
1389
- acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
963
+ ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
964
+ }
965
+
966
+ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
967
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
968
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
969
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
970
+ main_stream->parallel_for(
971
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
972
+ gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
973
+ });
974
+ });
975
+ }
976
+
977
+ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
978
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
979
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
980
+ const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
981
+ main_stream->parallel_for(
982
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
983
+ gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
984
+ });
985
+ });
986
+ }
987
+
988
+ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
989
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
990
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
991
+ const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
992
+ main_stream->parallel_for(
993
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
994
+ gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
995
+ });
996
+ });
997
+ }
998
+
999
+ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1000
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
1001
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
1002
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
1003
+ main_stream->parallel_for(
1004
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
1005
+ gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
1006
+ });
1007
+ });
1008
+ }
1009
+
1010
+ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1011
+ ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
1012
+ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
1013
+ const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
1014
+ main_stream->parallel_for(
1015
+ sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
1016
+ gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
1017
+ });
1018
+ });
1390
1019
  }
1391
1020
 
1392
1021
 
@@ -1425,6 +1054,11 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1425
1054
  ggml_sycl_op_gelu_quick(ctx, dst);
1426
1055
  }
1427
1056
 
1057
+ void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1058
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1059
+ ggml_sycl_op_gelu_erf(ctx, dst);
1060
+ }
1061
+
1428
1062
  void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1429
1063
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1430
1064
  ggml_sycl_op_tanh(ctx, dst);
@@ -1509,3 +1143,28 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1509
1143
  scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1510
1144
  ggml_sycl_op_elu(ctx, dst);
1511
1145
  }
1146
+
1147
+ void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1148
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1149
+ ggml_sycl_op_geglu(ctx, dst);
1150
+ }
1151
+
1152
+ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1153
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1154
+ ggml_sycl_op_reglu(ctx, dst);
1155
+ }
1156
+
1157
+ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1158
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1159
+ ggml_sycl_op_swiglu(ctx, dst);
1160
+ }
1161
+
1162
+ void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1163
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1164
+ ggml_sycl_op_geglu_erf(ctx, dst);
1165
+ }
1166
+
1167
+ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1168
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1169
+ ggml_sycl_op_geglu_quick(ctx, dst);
1170
+ }