whispercpp 1.3.1 → 1.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (797):
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -0,0 +1,2957 @@
1
+ //
2
+ // MIT license
3
+ // Copyright (C) 2024 Intel Corporation
4
+ // SPDX-License-Identifier: MIT
5
+ //
6
+
7
+ //
8
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9
+ // See https://llvm.org/LICENSE.txt for license information.
10
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11
+ //
12
+
13
+ #ifndef GGML_SYCL_DPCT_HELPER_HPP
14
+ #define GGML_SYCL_DPCT_HELPER_HPP
15
+
16
+ #include <sycl/sycl.hpp>
17
+ #include <sycl/half_type.hpp>
18
+ #include <syclcompat/math.hpp>
19
+ #include <map>
20
+
21
+ #ifdef GGML_SYCL_USE_INTEL_ONEMKL
22
+ #include <oneapi/mkl.hpp>
23
+ // Allow to use the same namespace for Intel oneMKL and oneMath
24
+ namespace oneapi {
25
+ namespace math = mkl;
26
+ }
27
+ #else
28
+ #include <oneapi/math.hpp>
29
+ #endif
30
+
31
+ #include "ggml.h"
32
+
33
+ #if defined(__linux__)
34
+ #include <sys/mman.h>
35
+ #elif defined(_WIN64)
36
+ #ifndef NOMINMAX
37
+ #define NOMINMAX
38
+ #endif
39
+ #include <windows.h>
40
+ #else
41
+ #error "Only support Windows and Linux."
42
+ #endif
43
+
44
+ #if defined(__linux__)
45
+ #include <unistd.h>
46
+ #include <sys/syscall.h>
47
+ #endif
48
+ #if defined(_WIN64)
49
+ #ifndef NOMINMAX
50
+ #define NOMINMAX
51
+ #endif
52
+ #include <windows.h>
53
+ #endif
54
+
55
// Compatibility-level constant exposed by this helper header.
#define DPCT_COMPATIBILITY_TEMP (900)

// Portable spellings of the alignment / inlining attributes used by the dpct
// helpers: MSVC understands __declspec(...), every other supported compiler
// gets the GNU __attribute__((...)) form.
#if defined(_MSC_VER)
#    define __dpct_align__(n) __declspec(align(n))
#    define __dpct_inline__   __forceinline
#    define __dpct_noinline__ __declspec(noinline)
#else
#    define __dpct_align__(n) __attribute__((aligned(n)))
#    define __dpct_inline__   __inline__ __attribute__((always_inline))
#    define __dpct_noinline__ __attribute__((noinline))
#endif
70
+
71
+ inline std::string get_device_type_name(const sycl::device &Device) {
72
+ auto DeviceType = Device.get_info<sycl::info::device::device_type>();
73
+ switch (DeviceType) {
74
+ case sycl::info::device_type::cpu:
75
+ return "cpu";
76
+ case sycl::info::device_type::gpu:
77
+ return "gpu";
78
+ case sycl::info::device_type::host:
79
+ return "host";
80
+ case sycl::info::device_type::accelerator:
81
+ return "acc";
82
+ default:
83
+ return "unknown";
84
+ }
85
+ }
86
+
87
+ inline std::string get_device_backend_and_type(const sycl::device &device) {
88
+ std::stringstream device_type;
89
+ sycl::backend backend = device.get_backend();
90
+ device_type << backend << ":" << get_device_type_name(device);
91
+ return device_type.str();
92
+ }
93
+
94
// Plain aggregate of matrix-operation parameters, templated on the scalar type Ts.
// NOTE(review): the per-field meanings below are inferred from the field names
// only; confirm against the code that fills this struct.
template <typename Ts> struct matrix_info_t {
    oneapi::math::transpose transpose_info[2]; // presumably transpose ops for the two input matrices
    Ts value_info[2];                          // presumably the two scalar factors (alpha / beta?) — TODO confirm
    std::int64_t size_info[3];                 // presumably the three matrix dimensions — TODO confirm
    std::int64_t ld_info[3];                   // presumably leading dimensions of the three matrices
    std::int64_t groupsize_info;               // presumably a batch/group count — TODO confirm
};
101
+
102
+ inline auto get_onemath_backend(sycl::queue& queue)
103
+ #if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
104
+ -> sycl::queue&
105
+ #endif
106
+ {
107
+ // If the backend is known at compile-time, use oneMath backend_selector to use
108
+ // compile-time dispatching and avoid the need to dlopen libraries. Otherwise
109
+ // fallback to runtime dispatching.
110
+ #if defined(GGML_SYCL_NVIDIA)
111
+ return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
112
+ #elif defined(GGML_SYCL_AMD)
113
+ return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
114
+ #elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
115
+ return queue;
116
+ #else
117
+ static_assert(false, "Unsupported backend");
118
+ #endif
119
+ }
120
+
121
+ namespace dpct
122
+ {
123
+ typedef sycl::queue *queue_ptr;
124
+ typedef sycl::event *event_ptr;
125
+ typedef char *device_ptr;
126
+ typedef uint8_t byte_t;
127
+ typedef sycl::buffer<byte_t> buffer_t;
128
+
129
+ /// SYCL default exception handler
130
+ inline auto exception_handler = [](sycl::exception_list exceptions)
131
+ {
132
+ for (std::exception_ptr const &e : exceptions)
133
+ {
134
+ try
135
+ {
136
+ std::rethrow_exception(e);
137
+ }
138
+ catch (sycl::exception const &e)
139
+ {
140
+ std::cerr << "Caught asynchronous SYCL exception:" << std::endl
141
+ << e.what() << std::endl
142
+ << "Exception caught at file:" << __FILE__
143
+ << ", line:" << __LINE__ << std::endl;
144
+ }
145
+ }
146
+ };
147
+
148
// Result codes: `success` (0) or the generic failure value `default_error` (999).
enum error_code { success = 0, default_error = 999 };
153
+
154
// Direction of a memory copy; `automatic` is the last enumerator.
// Values are spelled out explicitly and match the implicit 0..4 numbering.
enum memcpy_direction
{
    host_to_host     = 0,
    host_to_device   = 1,
    device_to_host   = 2,
    device_to_device = 3,
    automatic        = 4
};
162
+
163
// Memory region kinds, numbered 0..3 in declaration order.
enum memory_region
{
    global   = 0, // device global memory
    constant = 1, // device constant memory
    local    = 2, // device local memory
    shared   = 3, // memory which can be accessed by host and device
};
170
+
171
+ enum class library_data_t : unsigned char
172
+ {
173
+ real_float = 0,
174
+ complex_float,
175
+ real_double,
176
+ complex_double,
177
+ real_half,
178
+ complex_half,
179
+ real_bfloat16,
180
+ complex_bfloat16,
181
+ real_int4,
182
+ complex_int4,
183
+ real_uint4,
184
+ complex_uint4,
185
+ real_int8,
186
+ complex_int8,
187
+ real_uint8,
188
+ complex_uint8,
189
+ real_int16,
190
+ complex_int16,
191
+ real_uint16,
192
+ complex_uint16,
193
+ real_int32,
194
+ complex_int32,
195
+ real_uint32,
196
+ complex_uint32,
197
+ real_int64,
198
+ complex_int64,
199
+ real_uint64,
200
+ complex_uint64,
201
+ real_int8_4,
202
+ real_int8_32,
203
+ real_uint8_4,
204
+ library_data_t_size
205
+ };
206
+
207
+ template <typename T>
208
+ struct DataType
209
+ {
210
+ using T2 = T;
211
+ };
212
+ template <typename T>
213
+ struct DataType<sycl::vec<T, 2>>
214
+ {
215
+ using T2 = std::complex<T>;
216
+ };
217
+
218
+ static void destroy_event(event_ptr event)
219
+ {
220
+ delete event;
221
+ }
222
+
223
+ static inline unsigned int get_tid()
224
+ {
225
+ #if defined(__linux__)
226
+ return syscall(SYS_gettid);
227
+ #elif defined(_WIN64)
228
+ return GetCurrentThreadId();
229
+ #else
230
+ #error "Only support Windows and Linux."
231
+ #endif
232
+ }
233
+
234
+ namespace detail
235
+ {
236
+ static void get_version(const sycl::device &dev, int &major, int &minor)
237
+ {
238
+ // Version string has the following format:
239
+ // a. OpenCL<space><major.minor><space><vendor-specific-information>
240
+ // b. <major.minor>
241
+ // c. <AmdGcnArchName> e.g gfx1030
242
+ std::string ver;
243
+ ver = dev.get_info<sycl::info::device::version>();
244
+ std::string::size_type i = 0;
245
+ while (i < ver.size()) {
246
+ if (isdigit(ver[i]))
247
+ break;
248
+ i++;
249
+ }
250
+ major = std::stoi(&(ver[i]));
251
+ while (i < ver.size()) {
252
+ if (ver[i] == '.')
253
+ break;
254
+ i++;
255
+ }
256
+ if (i < ver.size()) {
257
+ // a. and b.
258
+ i++;
259
+ minor = std::stoi(&(ver[i]));
260
+ } else {
261
+ // c.
262
+ minor = 0;
263
+ }
264
+ }
265
+
266
+ template <typename tag, typename T>
267
+ class generic_error_type
268
+ {
269
+ public:
270
+ generic_error_type() = default;
271
+ generic_error_type(T value) : value{value} {}
272
+ operator T() const { return value; }
273
+
274
+ private:
275
+ T value;
276
+ };
277
+
278
+ } // namespace detail
279
+
280
+ /// Pitched 2D/3D memory data.
281
+ class pitched_data
282
+ {
283
+ public:
284
+ pitched_data() : pitched_data(nullptr, 0, 0, 0) {}
285
+ pitched_data(void *data, size_t pitch, size_t x, size_t y)
286
+ : _data(data), _pitch(pitch), _x(x), _y(y) {}
287
+
288
+ void *get_data_ptr() { return _data; }
289
+ void set_data_ptr(void *data) { _data = data; }
290
+
291
+ size_t get_pitch() { return _pitch; }
292
+ void set_pitch(size_t pitch) { _pitch = pitch; }
293
+
294
+ size_t get_x() { return _x; }
295
+ void set_x(size_t x) { _x = x; }
296
+
297
+ size_t get_y() { return _y; }
298
+ void set_y(size_t y) { _y = y; }
299
+
300
+ private:
301
+ void *_data;
302
+ size_t _pitch, _x, _y;
303
+ };
304
+
305
+ class device_info
306
+ {
307
+ public:
308
+ // get interface
309
+ const char *get_name() const { return _name; }
310
+ char *get_name() { return _name; }
311
+ template <typename WorkItemSizesTy = sycl::range<3>,
312
+ std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
313
+ std::is_same_v<WorkItemSizesTy, int *>,
314
+ int> = 0>
315
+ auto get_max_work_item_sizes() const
316
+ {
317
+ if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
318
+ return sycl::range<3>(_max_work_item_sizes_i[0],
319
+ _max_work_item_sizes_i[1],
320
+ _max_work_item_sizes_i[2]);
321
+ else
322
+ {
323
+ return _max_work_item_sizes_i;
324
+ }
325
+ }
326
+ template <typename WorkItemSizesTy = sycl::range<3>,
327
+ std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
328
+ std::is_same_v<WorkItemSizesTy, int *>,
329
+ int> = 0>
330
+ auto get_max_work_item_sizes()
331
+ {
332
+ if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
333
+ return sycl::range<3>(_max_work_item_sizes_i[0],
334
+ _max_work_item_sizes_i[1],
335
+ _max_work_item_sizes_i[2]);
336
+ else
337
+ {
338
+ return _max_work_item_sizes_i;
339
+ }
340
+ }
341
+ bool get_host_unified_memory() const { return _host_unified_memory; }
342
+ int get_major_version() const { return _major; }
343
+ int get_minor_version() const { return _minor; }
344
+ int get_integrated() const { return _integrated; }
345
+ int get_max_clock_frequency() const { return _frequency; }
346
+ int get_max_compute_units() const { return _max_compute_units; }
347
+ int get_max_work_group_size() const { return _max_work_group_size; }
348
+ int get_max_sub_group_size() const { return _max_sub_group_size; }
349
+ int get_max_work_items_per_compute_unit() const
350
+ {
351
+ return _max_work_items_per_compute_unit;
352
+ }
353
+ int get_max_register_size_per_work_group() const
354
+ {
355
+ return _max_register_size_per_work_group;
356
+ }
357
+ template <typename NDRangeSizeTy = size_t *,
358
+ std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
359
+ std::is_same_v<NDRangeSizeTy, int *>,
360
+ int> = 0>
361
+ auto get_max_nd_range_size() const
362
+ {
363
+ if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
364
+ return _max_nd_range_size;
365
+ else
366
+ return _max_nd_range_size_i;
367
+ }
368
+ template <typename NDRangeSizeTy = size_t *,
369
+ std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
370
+ std::is_same_v<NDRangeSizeTy, int *>,
371
+ int> = 0>
372
+ auto get_max_nd_range_size()
373
+ {
374
+ if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
375
+ return _max_nd_range_size;
376
+ else
377
+ return _max_nd_range_size_i;
378
+ }
379
+ size_t get_global_mem_size() const { return _global_mem_size; }
380
+ size_t get_local_mem_size() const { return _local_mem_size; }
381
+ size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
382
+ /// Returns the maximum clock rate of device's global memory in kHz. If
383
+ /// compiler does not support this API then returns default value 3200000 kHz.
384
+ unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
385
+ /// Returns the maximum bus width between device and memory in bits. If
386
+ /// compiler does not support this API then returns default value 64 bits.
387
+ unsigned int get_memory_bus_width() const { return _memory_bus_width; }
388
+ uint32_t get_device_id() const { return _device_id; }
389
+ std::array<unsigned char, 16> get_uuid() const { return _uuid; }
390
+ /// Returns global memory cache size in bytes.
391
+ unsigned int get_global_mem_cache_size() const
392
+ {
393
+ return _global_mem_cache_size;
394
+ }
395
+
396
+ // set interface
397
+ void set_name(const char *name)
398
+ {
399
+ size_t length = strlen(name);
400
+ if (length < 256)
401
+ {
402
+ std::memcpy(_name, name, length + 1);
403
+ }
404
+ else
405
+ {
406
+ std::memcpy(_name, name, 255);
407
+ _name[255] = '\0';
408
+ }
409
+ }
410
+ void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes)
411
+ {
412
+ for (int i = 0; i < 3; ++i)
413
+ _max_work_item_sizes_i[i] = max_work_item_sizes[i];
414
+ }
415
+ [[deprecated]] void
416
+ set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes)
417
+ {
418
+ for (int i = 0; i < 3; ++i)
419
+ {
420
+ _max_work_item_sizes_i[i] = max_work_item_sizes[i];
421
+ }
422
+ }
423
+ void set_host_unified_memory(bool host_unified_memory)
424
+ {
425
+ _host_unified_memory = host_unified_memory;
426
+ }
427
+ void set_major_version(int major) { _major = major; }
428
+ void set_minor_version(int minor) { _minor = minor; }
429
+ void set_integrated(int integrated) { _integrated = integrated; }
430
+ void set_max_clock_frequency(int frequency) { _frequency = frequency; }
431
+ void set_max_compute_units(int max_compute_units)
432
+ {
433
+ _max_compute_units = max_compute_units;
434
+ }
435
+ void set_global_mem_size(size_t global_mem_size)
436
+ {
437
+ _global_mem_size = global_mem_size;
438
+ }
439
+ void set_local_mem_size(size_t local_mem_size)
440
+ {
441
+ _local_mem_size = local_mem_size;
442
+ }
443
+ void set_max_mem_alloc_size(size_t max_mem_alloc_size)
444
+ {
445
+ _max_mem_alloc_size = max_mem_alloc_size;
446
+ }
447
+ void set_max_work_group_size(int max_work_group_size)
448
+ {
449
+ _max_work_group_size = max_work_group_size;
450
+ }
451
+ void set_max_sub_group_size(int max_sub_group_size)
452
+ {
453
+ _max_sub_group_size = max_sub_group_size;
454
+ }
455
+ void
456
+ set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit)
457
+ {
458
+ _max_work_items_per_compute_unit = max_work_items_per_compute_unit;
459
+ }
460
+ void set_max_nd_range_size(int max_nd_range_size[])
461
+ {
462
+ for (int i = 0; i < 3; i++)
463
+ {
464
+ _max_nd_range_size[i] = max_nd_range_size[i];
465
+ _max_nd_range_size_i[i] = max_nd_range_size[i];
466
+ }
467
+ }
468
+ void set_memory_clock_rate(unsigned int memory_clock_rate)
469
+ {
470
+ _memory_clock_rate = memory_clock_rate;
471
+ }
472
+ void set_memory_bus_width(unsigned int memory_bus_width)
473
+ {
474
+ _memory_bus_width = memory_bus_width;
475
+ }
476
+ void
477
+ set_max_register_size_per_work_group(int max_register_size_per_work_group)
478
+ {
479
+ _max_register_size_per_work_group = max_register_size_per_work_group;
480
+ }
481
+ void set_device_id(uint32_t device_id)
482
+ {
483
+ _device_id = device_id;
484
+ }
485
+ void set_uuid(std::array<unsigned char, 16> uuid)
486
+ {
487
+ _uuid = std::move(uuid);
488
+ }
489
+ void set_global_mem_cache_size(unsigned int global_mem_cache_size)
490
+ {
491
+ _global_mem_cache_size = global_mem_cache_size;
492
+ }
493
+
494
+ private:
495
+ char _name[256];
496
+ int _max_work_item_sizes_i[3];
497
+ bool _host_unified_memory = false;
498
+ int _major;
499
+ int _minor;
500
+ int _integrated = 0;
501
+ int _frequency;
502
+ // Set estimated value 3200000 kHz as default value.
503
+ unsigned int _memory_clock_rate = 3200000;
504
+ // Set estimated value 64 bits as default value.
505
+ unsigned int _memory_bus_width = 64;
506
+ unsigned int _global_mem_cache_size;
507
+ int _max_compute_units;
508
+ int _max_work_group_size;
509
+ int _max_sub_group_size;
510
+ int _max_work_items_per_compute_unit;
511
+ int _max_register_size_per_work_group;
512
+ size_t _global_mem_size;
513
+ size_t _local_mem_size;
514
+ size_t _max_mem_alloc_size;
515
+ size_t _max_nd_range_size[3];
516
+ int _max_nd_range_size_i[3];
517
+ uint32_t _device_id;
518
+ std::array<unsigned char, 16> _uuid;
519
+ };
520
+
521
+ static int get_major_version(const sycl::device &dev)
522
+ {
523
+ int major, minor;
524
+ detail::get_version(dev, major, minor);
525
+ return major;
526
+ }
527
+
528
+ static int get_minor_version(const sycl::device &dev)
529
+ {
530
+ int major, minor;
531
+ detail::get_version(dev, major, minor);
532
+ return minor;
533
+ }
534
+
535
+ static void get_device_info(device_info &out, const sycl::device &dev)
536
+ {
537
+ device_info prop;
538
+ prop.set_name(dev.get_info<sycl::info::device::name>().c_str());
539
+
540
+ int major, minor;
541
+ detail::get_version(dev, major, minor);
542
+ prop.set_major_version(major);
543
+ prop.set_minor_version(minor);
544
+
545
+ prop.set_max_work_item_sizes(
546
+ #if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)
547
+ // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes
548
+ // is an enum class element
549
+ dev.get_info<sycl::info::device::max_work_item_sizes>());
550
+ #else
551
+ // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by
552
+ // an int
553
+ dev.get_info<sycl::info::device::max_work_item_sizes<3>>());
554
+ #endif
555
+ prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));
556
+
557
+ prop.set_max_clock_frequency(
558
+ dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);
559
+
560
+ prop.set_max_compute_units(
561
+ dev.get_info<sycl::info::device::max_compute_units>());
562
+ prop.set_max_work_group_size(
563
+ dev.get_info<sycl::info::device::max_work_group_size>());
564
+ prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
565
+ prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
566
+ prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());
567
+
568
+ #if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
569
+ if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
570
+ {
571
+ unsigned int tmp =
572
+ dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();
573
+ if (tmp != 0)
574
+ prop.set_memory_clock_rate(1000 * tmp);
575
+ }
576
+ if (dev.has(sycl::aspect::ext_intel_memory_bus_width))
577
+ {
578
+ prop.set_memory_bus_width(
579
+ dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());
580
+ }
581
+ if (dev.has(sycl::aspect::ext_intel_device_id))
582
+ {
583
+ prop.set_device_id(
584
+ dev.get_info<sycl::ext::intel::info::device::device_id>());
585
+ }
586
+ if (dev.has(sycl::aspect::ext_intel_device_info_uuid))
587
+ {
588
+ prop.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());
589
+ }
590
+ #elif defined(_MSC_VER) && !defined(__clang__)
591
+ #pragma message("get_device_info: querying memory_clock_rate and \
592
+ memory_bus_width are not supported by the compiler used. \
593
+ Use 3200000 kHz as memory_clock_rate default value. \
594
+ Use 64 bits as memory_bus_width default value.")
595
+ #else
596
+ #warning "get_device_info: querying memory_clock_rate and \
597
+ memory_bus_width are not supported by the compiler used. \
598
+ Use 3200000 kHz as memory_clock_rate default value. \
599
+ Use 64 bits as memory_bus_width default value."
600
+ #endif
601
+
602
+ size_t max_sub_group_size = 1;
603
+ std::vector<size_t> sub_group_sizes =
604
+ dev.get_info<sycl::info::device::sub_group_sizes>();
605
+
606
+ for (const auto &sub_group_size : sub_group_sizes)
607
+ {
608
+ if (max_sub_group_size < sub_group_size)
609
+ max_sub_group_size = sub_group_size;
610
+ }
611
+
612
+ prop.set_max_sub_group_size(max_sub_group_size);
613
+
614
+ prop.set_max_work_items_per_compute_unit(
615
+ dev.get_info<sycl::info::device::max_work_group_size>());
616
+ int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};
617
+ prop.set_max_nd_range_size(max_nd_range_size);
618
+
619
+ // Estimates max register size per work group, feel free to update the value
620
+ // according to device properties.
621
+ prop.set_max_register_size_per_work_group(65536);
622
+
623
+ prop.set_global_mem_cache_size(
624
+ dev.get_info<sycl::info::device::global_mem_cache_size>());
625
+ out = prop;
626
+ }
627
+
628
+ /// dpct device extension
629
+ class device_ext : public sycl::device {
630
+ typedef std::mutex mutex_type;
631
+
632
+ public:
633
+ device_ext() : sycl::device() {}
634
+ ~device_ext() {
635
+ std::lock_guard<mutex_type> lock(m_mutex);
636
+ clear_queues();
637
+ }
638
+ device_ext(const sycl::device &base) : sycl::device(base) {
639
+ std::lock_guard<mutex_type> lock(m_mutex);
640
+ init_queues();
641
+ }
642
+
643
+ int is_native_atomic_supported() { return 0; }
644
+ int get_major_version() const { return dpct::get_major_version(*this); }
645
+
646
+ int get_minor_version() const { return dpct::get_minor_version(*this); }
647
+
648
+ int get_max_compute_units() const {
649
+ return get_device_info().get_max_compute_units();
650
+ }
651
+
652
+ /// Return the maximum clock frequency of this device in KHz.
653
+ int get_max_clock_frequency() const {
654
+ return get_device_info().get_max_clock_frequency();
655
+ }
656
+
657
+ int get_integrated() const { return get_device_info().get_integrated(); }
658
+
659
+ int get_max_sub_group_size() const {
660
+ return get_device_info().get_max_sub_group_size();
661
+ }
662
+
663
+ int get_max_register_size_per_work_group() const {
664
+ return get_device_info().get_max_register_size_per_work_group();
665
+ }
666
+
667
+ int get_max_work_group_size() const {
668
+ return get_device_info().get_max_work_group_size();
669
+ }
670
+
671
+ int get_mem_base_addr_align() const {
672
+ return get_info<sycl::info::device::mem_base_addr_align>();
673
+ }
674
+
675
+ size_t get_global_mem_size() const {
676
+ return get_device_info().get_global_mem_size();
677
+ }
678
+
679
+ size_t get_max_mem_alloc_size() const {
680
+ return get_device_info().get_max_mem_alloc_size();
681
+ }
682
+
683
+ /// Get the number of bytes of free and total memory on the SYCL device.
684
+ /// \param [out] free_memory The number of bytes of free memory on the
685
+ /// SYCL device. \param [out] total_memory The number of bytes of total
686
+ /// memory on the SYCL device.
687
+ void get_memory_info(size_t &free_memory, size_t &total_memory) {
688
+ total_memory = get_device_info().get_global_mem_size();
689
+ const char *warning_info =
690
+ "get_memory_info: [warning] ext_intel_free_memory is not "
691
+ "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
692
+ "use total memory as free memory";
693
+ #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
694
+ if (!has(sycl::aspect::ext_intel_free_memory)) {
695
+ std::cerr << warning_info << std::endl;
696
+ free_memory = total_memory;
697
+ } else {
698
+ free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
699
+ }
700
+ #else
701
+ std::cerr << warning_info << std::endl;
702
+ free_memory = total_memory;
703
+ #if defined(_MSC_VER) && !defined(__clang__)
704
+ #pragma message("Querying the number of bytes of free memory is not supported")
705
+ #else
706
+ #warning "Querying the number of bytes of free memory is not supported"
707
+ #endif
708
+ #endif
709
+ }
710
+
711
+ void get_device_info(device_info &out) const {
712
+ dpct::get_device_info(out, *this);
713
+ }
714
+
715
+ device_info get_device_info() const {
716
+ device_info prop;
717
+ dpct::get_device_info(prop, *this);
718
+ return prop;
719
+ }
720
+
721
+ void reset() {
722
+ std::lock_guard<mutex_type> lock(m_mutex);
723
+ clear_queues();
724
+ init_queues();
725
+ }
726
+
727
+ sycl::queue &in_order_queue() { return _q_in_order; }
728
+
729
+ sycl::queue &out_of_order_queue() { return _q_out_of_order; }
730
+
731
+ sycl::queue &default_queue() { return in_order_queue(); }
732
+
733
+ void queues_wait_and_throw() {
734
+ std::unique_lock<mutex_type> lock(m_mutex);
735
+ lock.unlock();
736
+ for (auto &q : _queues) {
737
+ q.wait_and_throw();
738
+ }
739
+ // Guard the destruct of current_queues to make sure the ref count is
740
+ // safe.
741
+ lock.lock();
742
+ }
743
+
744
+ sycl::queue create_queue(bool enable_exception_handler = false) {
745
+ return create_in_order_queue(enable_exception_handler);
746
+ }
747
+
748
+ sycl::queue create_queue(sycl::device device,
749
+ bool enable_exception_handler = false) {
750
+ return create_in_order_queue(device, enable_exception_handler);
751
+ }
752
+
753
+ sycl::queue create_in_order_queue(bool enable_exception_handler = false) {
754
+ std::lock_guard<mutex_type> lock(m_mutex);
755
+ return create_queue_impl(enable_exception_handler,
756
+ sycl::property::queue::in_order());
757
+ }
758
+
759
+ sycl::queue create_in_order_queue(sycl::device device,
760
+ bool enable_exception_handler = false) {
761
+ std::lock_guard<mutex_type> lock(m_mutex);
762
+ return create_queue_impl(device, enable_exception_handler,
763
+ sycl::property::queue::in_order());
764
+ }
765
+
766
+ sycl::queue create_out_of_order_queue(
767
+ bool enable_exception_handler = false) {
768
+ std::lock_guard<mutex_type> lock(m_mutex);
769
+ return create_queue_impl(enable_exception_handler);
770
+ }
771
+
772
+ void destroy_queue(sycl::queue queue) {
773
+ std::lock_guard<mutex_type> lock(m_mutex);
774
+ _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
775
+ [=](const sycl::queue &q) -> bool
776
+ {
777
+ return q == queue;
778
+ }),
779
+ _queues.end());
780
+ }
781
+ void set_saved_queue(sycl::queue q) {
782
+ std::lock_guard<mutex_type> lock(m_mutex);
783
+ _saved_queue = q;
784
+ }
785
+ sycl::queue get_saved_queue() const {
786
+ std::lock_guard<mutex_type> lock(m_mutex);
787
+ return _saved_queue;
788
+ }
789
+
790
+ private:
791
+ void clear_queues() { _queues.clear(); }
792
+
793
+ void init_queues() {
794
+ _q_in_order =
795
+ create_queue_impl(true, sycl::property::queue::in_order());
796
+ _q_out_of_order = create_queue_impl(true);
797
+ _saved_queue = default_queue();
798
+ }
799
+
800
+ /// Caller should acquire resource \p m_mutex before calling this
801
+ /// function.
802
+ template <class... Properties>
803
+ sycl::queue create_queue_impl(bool enable_exception_handler,
804
+ Properties... properties) {
805
+ sycl::async_handler eh = {};
806
+ if (enable_exception_handler) {
807
+ eh = exception_handler;
808
+ }
809
+ _queues.push_back(sycl::queue(
810
+ *this, eh,
811
+ sycl::property_list(
812
+ #ifdef DPCT_PROFILING_ENABLED
813
+ sycl::property::queue::enable_profiling(),
814
+ #endif
815
+ properties...)));
816
+
817
+ return _queues.back();
818
+ }
819
+
820
+ template <class... Properties>
821
+ sycl::queue create_queue_impl(sycl::device device,
822
+ bool enable_exception_handler,
823
+ Properties... properties) {
824
+ sycl::async_handler eh = {};
825
+ if (enable_exception_handler) {
826
+ eh = exception_handler;
827
+ }
828
+ _queues.push_back(sycl::queue(
829
+ device, eh,
830
+ sycl::property_list(
831
+ #ifdef DPCT_PROFILING_ENABLED
832
+ sycl::property::queue::enable_profiling(),
833
+ #endif
834
+ properties...)));
835
+
836
+ return _queues.back();
837
+ }
838
+
839
+ void get_version(int &major, int &minor) const {
840
+ detail::get_version(*this, major, minor);
841
+ }
842
+ sycl::queue _q_in_order, _q_out_of_order;
843
+ sycl::queue _saved_queue;
844
+ std::vector<sycl::queue> _queues;
845
+ mutable mutex_type m_mutex;
846
+ };
847
+
848
+
849
+ /// device manager
850
+ class dev_mgr
851
+ {
852
+ public:
853
+ device_ext &current_device()
854
+ {
855
+ unsigned int dev_id = current_device_id();
856
+ check_id(dev_id);
857
+ return *_devs[dev_id];
858
+ }
859
+ device_ext &cpu_device() const
860
+ {
861
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
862
+ if (_cpu_device == -1)
863
+ {
864
+ throw std::runtime_error("no valid cpu device");
865
+ }
866
+ else
867
+ {
868
+ return *_devs[_cpu_device];
869
+ }
870
+ }
871
+ device_ext &get_device(unsigned int id) const
872
+ {
873
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
874
+ check_id(id);
875
+ return *_devs[id];
876
+ }
877
+ unsigned int current_device_id() const
878
+ {
879
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
880
+ auto it = _thread2dev_map.find(get_tid());
881
+ if (it != _thread2dev_map.end())
882
+ return it->second;
883
+ return DEFAULT_DEVICE_ID;
884
+ }
885
+
886
+ /// Select device with a device ID.
887
+ /// \param [in] id The id of the device which can
888
+ /// be obtained through get_device_id(const sycl::device).
889
+ void select_device(unsigned int id)
890
+ {
891
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
892
+ check_id(id);
893
+ _thread2dev_map[get_tid()] = id;
894
+ }
895
+ unsigned int device_count() { return _devs.size(); }
896
+
897
+ unsigned int get_device_id(const sycl::device &dev)
898
+ {
899
+ unsigned int id = 0;
900
+ for (auto &dev_item : _devs)
901
+ {
902
+ if (*dev_item == dev)
903
+ {
904
+ return id;
905
+ }
906
+ id++;
907
+ }
908
+ return -1;
909
+ }
910
+
911
+ inline std::string get_preferred_gpu_platform_name() {
912
+ std::string result;
913
+
914
+ std::string filter = "";
915
+ char* env = getenv("ONEAPI_DEVICE_SELECTOR");
916
+ if (env) {
917
+ if (std::strstr(env, "level_zero")) {
918
+ filter = "level-zero";
919
+ }
920
+ else if (std::strstr(env, "opencl")) {
921
+ filter = "opencl";
922
+ }
923
+ else if (std::strstr(env, "cuda")) {
924
+ filter = "cuda";
925
+ }
926
+ else if (std::strstr(env, "hip")) {
927
+ filter = "hip";
928
+ }
929
+ else {
930
+ throw std::runtime_error("invalid device filter: " + std::string(env));
931
+ }
932
+ } else {
933
+ auto default_device = sycl::device(sycl::default_selector_v);
934
+ auto default_platform_name = default_device.get_platform().get_info<sycl::info::platform::name>();
935
+
936
+ if (std::strstr(default_platform_name.c_str(), "Level-Zero") || default_device.is_cpu()) {
937
+ filter = "level-zero";
938
+ }
939
+ else if (std::strstr(default_platform_name.c_str(), "CUDA")) {
940
+ filter = "cuda";
941
+ }
942
+ else if (std::strstr(default_platform_name.c_str(), "HIP")) {
943
+ filter = "hip";
944
+ }
945
+ }
946
+
947
+ auto platform_list = sycl::platform::get_platforms();
948
+
949
+ for (const auto& platform : platform_list) {
950
+ auto devices = platform.get_devices();
951
+ auto gpu_dev = std::find_if(devices.begin(), devices.end(), [](const sycl::device& d) {
952
+ return d.is_gpu();
953
+ });
954
+
955
+ if (gpu_dev == devices.end()) {
956
+ // cout << "platform [" << platform_name
957
+ // << "] does not contain GPU devices, skipping\n";
958
+ continue;
959
+ }
960
+
961
+ auto platform_name = platform.get_info<sycl::info::platform::name>();
962
+ std::string platform_name_low_case;
963
+ platform_name_low_case.resize(platform_name.size());
964
+
965
+ std::transform(
966
+ platform_name.begin(), platform_name.end(), platform_name_low_case.begin(), ::tolower);
967
+
968
+ if (platform_name_low_case.find(filter) == std::string::npos) {
969
+ // cout << "platform [" << platform_name
970
+ // << "] does not match with requested "
971
+ // << filter << ", skipping\n";
972
+ continue;
973
+ }
974
+
975
+ result = platform_name;
976
+ }
977
+
978
+ if (result.empty())
979
+ throw std::runtime_error("can not find preferred GPU platform");
980
+
981
+ return result;
982
+ }
983
+
984
+ template <class DeviceSelector>
985
+ std::enable_if_t<
986
+ std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
987
+ select_device(const DeviceSelector &selector = sycl::gpu_selector_v)
988
+ {
989
+ sycl::device selected_device = sycl::device(selector);
990
+ unsigned int selected_device_id = get_device_id(selected_device);
991
+ select_device(selected_device_id);
992
+ }
993
+
994
+ /// Returns the instance of device manager singleton.
995
+ static dev_mgr &instance()
996
+ {
997
+ static dev_mgr d_m;
998
+ return d_m;
999
+ }
1000
+ dev_mgr(const dev_mgr &) = delete;
1001
+ dev_mgr &operator=(const dev_mgr &) = delete;
1002
+ dev_mgr(dev_mgr &&) = delete;
1003
+ dev_mgr &operator=(dev_mgr &&) = delete;
1004
+
1005
+ private:
1006
+ mutable std::recursive_mutex m_mutex;
1007
+ static bool compare_dev(sycl::device &device1, sycl::device &device2)
1008
+ {
1009
+ sycl::backend backend1 = device1.get_backend();
1010
+ sycl::backend backend2 = device2.get_backend();
1011
+ // levelzero backends always come first
1012
+ if(backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true;
1013
+ if(backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false;
1014
+ dpct::device_info prop1;
1015
+ dpct::get_device_info(prop1, device1);
1016
+ dpct::device_info prop2;
1017
+ dpct::get_device_info(prop2, device2);
1018
+ return prop1.get_max_compute_units() > prop2.get_max_compute_units();
1019
+ }
1020
+ static int convert_backend_index(std::string & backend) {
1021
+ if (backend == "ext_oneapi_level_zero:gpu") return 0;
1022
+ if (backend == "opencl:gpu") return 1;
1023
+ if (backend == "ext_oneapi_cuda:gpu") return 2;
1024
+ if (backend == "ext_oneapi_hip:gpu") return 3;
1025
+ if (backend == "opencl:cpu") return 4;
1026
+ if (backend == "opencl:acc") return 5;
1027
+ printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
1028
+ GGML_ABORT("fatal error");
1029
+ }
1030
+ static bool compare_backend(std::string &backend1, std::string &backend2) {
1031
+ return convert_backend_index(backend1) < convert_backend_index(backend2);
1032
+ }
1033
+ dev_mgr()
1034
+ {
1035
+ sycl::device default_device =
1036
+ sycl::device(sycl::default_selector_v);
1037
+ _devs.push_back(std::make_shared<device_ext>(default_device));
1038
+
1039
+ std::vector<sycl::device> sycl_all_devs;
1040
+ // Collect other devices except for the default device.
1041
+ if (default_device.is_cpu())
1042
+ _cpu_device = 0;
1043
+
1044
+ auto Platforms = sycl::platform::get_platforms();
1045
+ // Keep track of the number of devices per backend
1046
+ std::map<sycl::backend, size_t> DeviceNums;
1047
+ std::map<std::string, std::vector<sycl::device>> backend_devices;
1048
+ auto preferred_platform_name = get_preferred_gpu_platform_name();
1049
+
1050
+ while (!Platforms.empty()) {
1051
+ auto Platform = Platforms.back();
1052
+ Platforms.pop_back();
1053
+ auto platform_name = Platform.get_info<sycl::info::platform::name>();
1054
+ if (platform_name.compare(preferred_platform_name) != 0) {
1055
+ continue;
1056
+ }
1057
+ auto devices = Platform.get_devices();
1058
+ std::string backend_type = get_device_backend_and_type(devices[0]);
1059
+ for (const auto &device : devices) {
1060
+ backend_devices[backend_type].push_back(device);
1061
+ }
1062
+ }
1063
+
1064
+ std::vector<std::string> keys;
1065
+ for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
1066
+ keys.push_back(it->first);
1067
+ }
1068
+ std::sort(keys.begin(), keys.end(), compare_backend);
1069
+
1070
+ for (auto &key : keys) {
1071
+ std::vector<sycl::device> devs = backend_devices[key];
1072
+ std::sort(devs.begin(), devs.end(), compare_dev);
1073
+ for (const auto &dev : devs) {
1074
+ sycl_all_devs.push_back(dev);
1075
+ }
1076
+ }
1077
+
1078
+ for (auto &dev : sycl_all_devs)
1079
+ {
1080
+ if (dev == default_device)
1081
+ {
1082
+ continue;
1083
+ }
1084
+ _devs.push_back(std::make_shared<device_ext>(dev));
1085
+ if (_cpu_device == -1 && dev.is_cpu())
1086
+ {
1087
+ _cpu_device = _devs.size() - 1;
1088
+ }
1089
+ }
1090
+ }
1091
+ void check_id(unsigned int id) const
1092
+ {
1093
+ if (id >= _devs.size())
1094
+ {
1095
+ throw std::runtime_error("invalid device id");
1096
+ }
1097
+ }
1098
+ std::vector<std::shared_ptr<device_ext>> _devs;
1099
+ /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current
1100
+ /// thread id in _thread2dev_map, which means default device should be used
1101
+ /// for the current thread.
1102
+ const unsigned int DEFAULT_DEVICE_ID = 0;
1103
+ /// thread-id to device-id map.
1104
+ std::map<unsigned int, unsigned int> _thread2dev_map;
1105
+ int _cpu_device = -1;
1106
+ };
1107
+
1108
+ static inline sycl::queue &get_default_queue()
1109
+ {
1110
+ return dev_mgr::instance().current_device().default_queue();
1111
+ }
1112
+
1113
+ namespace detail
1114
+ {
1115
+ enum class pointer_access_attribute
1116
+ {
1117
+ host_only = 0,
1118
+ device_only,
1119
+ host_device,
1120
+ end
1121
+ };
1122
+
1123
+ static pointer_access_attribute get_pointer_attribute(sycl::queue &q,
1124
+ const void *ptr)
1125
+ {
1126
+ switch (sycl::get_pointer_type(ptr, q.get_context()))
1127
+ {
1128
+ case sycl::usm::alloc::unknown:
1129
+ return pointer_access_attribute::host_only;
1130
+ case sycl::usm::alloc::device:
1131
+ return pointer_access_attribute::device_only;
1132
+ case sycl::usm::alloc::shared:
1133
+ case sycl::usm::alloc::host:
1134
+ return pointer_access_attribute::host_device;
1135
+ }
1136
+ }
1137
+
1138
+ template <typename ArgT>
1139
+ inline constexpr std::uint64_t get_type_combination_id(ArgT Val)
1140
+ {
1141
+ static_assert((unsigned char)library_data_t::library_data_t_size <=
1142
+ std::numeric_limits<unsigned char>::max() &&
1143
+ "library_data_t size exceeds limit.");
1144
+ static_assert(std::is_same_v<ArgT, library_data_t>, "Unsupported ArgT");
1145
+ return (std::uint64_t)Val;
1146
+ }
1147
+
1148
+ template <typename FirstT, typename... RestT>
1149
+ inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,
1150
+ RestT... RestVal)
1151
+ {
1152
+ static_assert((std::uint8_t)library_data_t::library_data_t_size <=
1153
+ std::numeric_limits<unsigned char>::max() &&
1154
+ "library_data_t size exceeds limit.");
1155
+ static_assert(sizeof...(RestT) <= 8 && "Too many parameters");
1156
+ static_assert(std::is_same_v<FirstT, library_data_t>, "Unsupported FirstT");
1157
+ return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal);
1158
+ }
1159
+
1160
+ class mem_mgr
1161
+ {
1162
+ mem_mgr()
1163
+ {
1164
+ // Reserved address space, no real memory allocation happens here.
1165
+ #if defined(__linux__)
1166
+ mapped_address_space =
1167
+ (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE,
1168
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
1169
+ #elif defined(_WIN64)
1170
+ mapped_address_space = (byte_t *)VirtualAlloc(
1171
+ NULL, // NULL specified as the base address parameter
1172
+ mapped_region_size, // Size of allocation
1173
+ MEM_RESERVE, // Allocate reserved pages
1174
+ PAGE_NOACCESS); // Protection = no access
1175
+ #else
1176
+ #error "Only support Windows and Linux."
1177
+ #endif
1178
+ next_free = mapped_address_space;
1179
+ }
1180
+
1181
+ public:
1182
+ using buffer_id_t = int;
1183
+
1184
+ struct allocation
1185
+ {
1186
+ buffer_t buffer;
1187
+ byte_t *alloc_ptr;
1188
+ size_t size;
1189
+ };
1190
+
1191
+ ~mem_mgr()
1192
+ {
1193
+ #if defined(__linux__)
1194
+ munmap(mapped_address_space, mapped_region_size);
1195
+ #elif defined(_WIN64)
1196
+ VirtualFree(mapped_address_space, 0, MEM_RELEASE);
1197
+ #else
1198
+ #error "Only support Windows and Linux."
1199
+ #endif
1200
+ }
1201
+
1202
+ mem_mgr(const mem_mgr &) = delete;
1203
+ mem_mgr &operator=(const mem_mgr &) = delete;
1204
+ mem_mgr(mem_mgr &&) = delete;
1205
+ mem_mgr &operator=(mem_mgr &&) = delete;
1206
+
1207
+ /// Allocate
1208
+ void *mem_alloc(size_t size)
1209
+ {
1210
+ if (!size)
1211
+ return nullptr;
1212
+ std::lock_guard<std::mutex> lock(m_mutex);
1213
+ if (next_free + size > mapped_address_space + mapped_region_size)
1214
+ {
1215
+ throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool");
1216
+ }
1217
+ // Allocation
1218
+ sycl::range<1> r(size);
1219
+ buffer_t buf(r);
1220
+ allocation A{buf, next_free, size};
1221
+ // Map allocation to device pointer
1222
+ void *result = next_free;
1223
+ m_map.emplace(next_free + size, A);
1224
+ // Update pointer to the next free space.
1225
+ next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1);
1226
+
1227
+ return result;
1228
+ }
1229
+
1230
+ /// Deallocate
1231
+ void mem_free(const void *ptr)
1232
+ {
1233
+ if (!ptr)
1234
+ return;
1235
+ std::lock_guard<std::mutex> lock(m_mutex);
1236
+ auto it = get_map_iterator(ptr);
1237
+ m_map.erase(it);
1238
+ }
1239
+
1240
+ /// map: device pointer -> allocation(buffer, alloc_ptr, size)
1241
+ allocation translate_ptr(const void *ptr)
1242
+ {
1243
+ std::lock_guard<std::mutex> lock(m_mutex);
1244
+ auto it = get_map_iterator(ptr);
1245
+ return it->second;
1246
+ }
1247
+
1248
+ /// Check if the pointer represents device pointer or not.
1249
+ bool is_device_ptr(const void *ptr) const
1250
+ {
1251
+ std::lock_guard<std::mutex> lock(m_mutex);
1252
+ return (mapped_address_space <= ptr) &&
1253
+ (ptr < mapped_address_space + mapped_region_size);
1254
+ }
1255
+
1256
+ /// Returns the instance of memory manager singleton.
1257
+ static mem_mgr &instance()
1258
+ {
1259
+ static mem_mgr m;
1260
+ return m;
1261
+ }
1262
+
1263
+ private:
1264
+ std::map<byte_t *, allocation> m_map;
1265
+ mutable std::mutex m_mutex;
1266
+ byte_t *mapped_address_space;
1267
+ byte_t *next_free;
1268
+ const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024;
1269
+ const size_t alignment = 256;
1270
+ /// This padding may be defined to some positive value to debug
1271
+ /// out of bound accesses.
1272
+ const size_t extra_padding = 0;
1273
+
1274
+ std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
1275
+ {
1276
+ auto it = m_map.upper_bound(const_cast<byte_t *>(reinterpret_cast<const byte_t *>(ptr)));
1277
+ if (it == m_map.end())
1278
+ {
1279
+ // Not a virtual pointer.
1280
+ throw std::runtime_error("can not get buffer from non-virtual pointer");
1281
+ }
1282
+ const allocation &alloc = it->second;
1283
+ if (ptr < alloc.alloc_ptr)
1284
+ {
1285
+ // Out of bound.
1286
+ // This may happen if there's a gap between allocations due to alignment
1287
+ // or extra padding and pointer points to this gap.
1288
+ throw std::runtime_error("invalid virtual pointer");
1289
+ }
1290
+ return it;
1291
+ }
1292
+ };
1293
+
1294
+ template <class T, memory_region Memory, size_t Dimension>
1295
+ class accessor;
1296
+ template <memory_region Memory, class T = byte_t>
1297
+ class memory_traits
1298
+ {
1299
+ public:
1300
+ static constexpr sycl::access::target target =
1301
+ sycl::access::target::device;
1302
+ static constexpr sycl::access_mode mode =
1303
+ (Memory == constant) ? sycl::access_mode::read
1304
+ : sycl::access_mode::read_write;
1305
+ static constexpr size_t type_size = sizeof(T);
1306
+ using element_t =
1307
+ typename std::conditional<Memory == constant, const T, T>::type;
1308
+ using value_t = typename std::remove_cv<T>::type;
1309
+ template <size_t Dimension = 1>
1310
+ using accessor_t = typename std::conditional<
1311
+ Memory == local, sycl::local_accessor<value_t, Dimension>,
1312
+ sycl::accessor<T, Dimension, mode, target>>::type;
1313
+ using pointer_t = T *;
1314
+ };
1315
+
1316
+ static inline void *dpct_malloc(size_t size, sycl::queue &q)
1317
+ {
1318
+ return sycl::malloc_device(size, q.get_device(), q.get_context());
1319
+ }
1320
+
1321
+ #define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))
1322
+ static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z,
1323
+ sycl::queue &q)
1324
+ {
1325
+ pitch = PITCH_DEFAULT_ALIGN(x);
1326
+ return dpct_malloc(pitch * y * z, q);
1327
+ }
1328
+
1329
+ /**
1330
+ * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q.
1331
+ * @tparam valueT The type of the element to be set.
1332
+ * @param [in] q The queue in which the operation is done.
1333
+ * @param [in] dev_ptr Pointer to the virtual device memory address.
1334
+ * @param [in] value The value to be set.
1335
+ * @param [in] size Number of elements to be set to the value.
1336
+ * @return An event representing the memset operation.
1337
+ */
1338
+ template <typename valueT>
1339
+ static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,
1340
+ valueT value, size_t size)
1341
+ {
1342
+ return q.fill(dev_ptr, value, size);
1343
+ }
1344
+
1345
+ /**
1346
+ * @brief Sets \p value to the 3D memory region pointed by \p data in \p q.
1347
+ * @tparam valueT The type of the element to be set.
1348
+ * @param [in] q The queue in which the operation is done.
1349
+ * @param [in] data Pointer to the pitched device memory region.
1350
+ * @param [in] value The value to be set.
1351
+ * @param [in] size 3D memory region by number of elements.
1352
+ * @return An event list representing the memset operations.
1353
+ */
1354
+ template <typename valueT>
1355
+ static inline std::vector<sycl::event>
1356
+ dpct_memset(sycl::queue &q, pitched_data data, valueT value,
1357
+ sycl::range<3> size)
1358
+ {
1359
+ std::vector<sycl::event> event_list;
1360
+ size_t slice = data.get_pitch() * data.get_y();
1361
+ unsigned char *data_surface = (unsigned char *)data.get_data_ptr();
1362
+ for (size_t z = 0; z < size.get(2); ++z)
1363
+ {
1364
+ unsigned char *data_ptr = data_surface;
1365
+ for (size_t y = 0; y < size.get(1); ++y)
1366
+ {
1367
+ event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0)));
1368
+ data_ptr += data.get_pitch();
1369
+ }
1370
+ data_surface += slice;
1371
+ }
1372
+ return event_list;
1373
+ }
1374
+
1375
+ /**
1376
+ * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q.
1377
+ * @tparam valueT The type of the element to be set.
1378
+ * @param [in] q The queue in which the operation is done.
1379
+ * @param [in] ptr Pointer to the virtual device memory.
1380
+ * @param [in] pitch The pitch size by number of elements, including padding.
1381
+ * @param [in] val The value to be set.
1382
+ * @param [in] x The width of memory region by number of elements.
1383
+ * @param [in] y The height of memory region by number of elements.
1384
+ * @return An event list representing the memset operations.
1385
+ */
1386
+ template <typename valueT>
1387
+ static inline std::vector<sycl::event>
1388
+ dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x,
1389
+ size_t y)
1390
+ {
1391
+ return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val,
1392
+ sycl::range<3>(x, y, 1));
1393
+ }
1394
+
1395
+ static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,
1396
+ const void *from_ptr,
1397
+ memcpy_direction dir)
1398
+ {
1399
+ switch (dir)
1400
+ {
1401
+ case memcpy_direction::host_to_host:
1402
+ case memcpy_direction::host_to_device:
1403
+ case memcpy_direction::device_to_host:
1404
+ case memcpy_direction::device_to_device:
1405
+ return dir;
1406
+ case memcpy_direction::automatic:
1407
+ {
1408
+ // table[to_attribute][from_attribute]
1409
+ static const memcpy_direction
1410
+ direction_table[static_cast<unsigned>(pointer_access_attribute::end)]
1411
+ [static_cast<unsigned>(pointer_access_attribute::end)] =
1412
+ {{memcpy_direction::host_to_host,
1413
+ memcpy_direction::device_to_host,
1414
+ memcpy_direction::host_to_host},
1415
+ {memcpy_direction::host_to_device,
1416
+ memcpy_direction::device_to_device,
1417
+ memcpy_direction::device_to_device},
1418
+ {memcpy_direction::host_to_host,
1419
+ memcpy_direction::device_to_device,
1420
+ memcpy_direction::device_to_device}};
1421
+ return direction_table[static_cast<unsigned>(get_pointer_attribute(
1422
+ q, to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];
1423
+ }
1424
+ default:
1425
+ throw std::runtime_error("dpct_memcpy: invalid direction value");
1426
+ }
1427
+ }
1428
+
1429
+ static sycl::event
1430
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
1431
+ memcpy_direction direction,
1432
+ const std::vector<sycl::event> &dep_events = {})
1433
+ {
1434
+ if (!size)
1435
+ return sycl::event{};
1436
+ return q.memcpy(to_ptr, from_ptr, size, dep_events);
1437
+ GGML_UNUSED(direction);
1438
+ }
1439
+
1440
+ // Get actual copy range and make sure it will not exceed range.
1441
+ static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
1442
+ size_t pitch)
1443
+ {
1444
+ return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
1445
+ }
1446
+
1447
+ static inline size_t get_offset(sycl::id<3> id, size_t slice,
1448
+ size_t pitch)
1449
+ {
1450
+ return slice * id.get(2) + pitch * id.get(1) + id.get(0);
1451
+ }
1452
+
1453
+ /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
1454
+ /// and \p from_range to another specified by \p to_ptr and \p to_range.
1455
+ static inline std::vector<sycl::event>
1456
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
1457
+ sycl::range<3> to_range, sycl::range<3> from_range,
1458
+ sycl::id<3> to_id, sycl::id<3> from_id,
1459
+ sycl::range<3> size, memcpy_direction direction,
1460
+ const std::vector<sycl::event> &dep_events = {})
1461
+ {
1462
+ // RAII for host pointer
1463
+ class host_buffer
1464
+ {
1465
+ void *_buf;
1466
+ size_t _size;
1467
+ sycl::queue &_q;
1468
+ const std::vector<sycl::event> &_deps; // free operation depends
1469
+
1470
+ public:
1471
+ host_buffer(size_t size, sycl::queue &q,
1472
+ const std::vector<sycl::event> &deps)
1473
+ : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
1474
+ void *get_ptr() const { return _buf; }
1475
+ size_t get_size() const { return _size; }
1476
+ ~host_buffer()
1477
+ {
1478
+ if (_buf)
1479
+ {
1480
+ _q.submit([&](sycl::handler &cgh)
1481
+ {
1482
+ cgh.depends_on(_deps);
1483
+ cgh.host_task([buf = _buf] { std::free(buf); }); });
1484
+ }
1485
+ }
1486
+ };
1487
+ std::vector<sycl::event> event_list;
1488
+
1489
+ size_t to_slice = to_range.get(1) * to_range.get(0),
1490
+ from_slice = from_range.get(1) * from_range.get(0);
1491
+ unsigned char *to_surface =
1492
+ (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
1493
+ const unsigned char *from_surface =
1494
+ (const unsigned char *)from_ptr +
1495
+ get_offset(from_id, from_slice, from_range.get(0));
1496
+
1497
+ if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
1498
+ {
1499
+ return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
1500
+ direction, dep_events)};
1501
+ }
1502
+ direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
1503
+ size_t size_slice = size.get(1) * size.get(0);
1504
+ switch (direction)
1505
+ {
1506
+ case host_to_host:
1507
+ for (size_t z = 0; z < size.get(2); ++z)
1508
+ {
1509
+ unsigned char *to_ptr = to_surface;
1510
+ const unsigned char *from_ptr = from_surface;
1511
+ if (to_range.get(0) == from_range.get(0) &&
1512
+ to_range.get(0) == size.get(0))
1513
+ {
1514
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
1515
+ direction, dep_events));
1516
+ }
1517
+ else
1518
+ {
1519
+ for (size_t y = 0; y < size.get(1); ++y)
1520
+ {
1521
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
1522
+ direction, dep_events));
1523
+ to_ptr += to_range.get(0);
1524
+ from_ptr += from_range.get(0);
1525
+ }
1526
+ }
1527
+ to_surface += to_slice;
1528
+ from_surface += from_slice;
1529
+ }
1530
+ break;
1531
+ case host_to_device:
1532
+ {
1533
+ host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
1534
+ event_list);
1535
+ std::vector<sycl::event> host_events;
1536
+ if (to_slice == size_slice)
1537
+ {
1538
+ // Copy host data to a temp host buffer with the shape of target.
1539
+ host_events =
1540
+ dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
1541
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
1542
+ host_to_host, dep_events);
1543
+ }
1544
+ else
1545
+ {
1546
+ // Copy host data to a temp host buffer with the shape of target.
1547
+ host_events = dpct_memcpy(
1548
+ q, buf.get_ptr(), from_surface, to_range, from_range,
1549
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
1550
+ // If has padding data, not sure whether it is useless. So fill temp
1551
+ // buffer with it.
1552
+ std::vector<sycl::event>{
1553
+ dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
1554
+ device_to_host, dep_events)});
1555
+ }
1556
+ // Copy from temp host buffer to device with only one submit.
1557
+ event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
1558
+ buf.get_size(), host_to_device,
1559
+ host_events));
1560
+ break;
1561
+ }
1562
+ case device_to_host:
1563
+ {
1564
+ host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
1565
+ event_list);
1566
+ // Copy from host temp buffer to host target with reshaping.
1567
+ event_list = dpct_memcpy(
1568
+ q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
1569
+ sycl::id<3>(0, 0, 0), size, host_to_host,
1570
+ // Copy from device to temp host buffer with only one submit.
1571
+ std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
1572
+ buf.get_size(),
1573
+ device_to_host, dep_events)});
1574
+ break;
1575
+ }
1576
+ case device_to_device:
1577
+ event_list.push_back(q.submit([&](sycl::handler &cgh){
1578
+ cgh.depends_on(dep_events);
1579
+ cgh.parallel_for<class dpct_memcpy_3d_detail>(
1580
+ size,
1581
+ [=](sycl::id<3> id) {
1582
+ to_surface[get_offset(id, to_slice, to_range.get(0))] =
1583
+ from_surface[get_offset(id, from_slice, from_range.get(0))];
1584
+ }); }));
1585
+ break;
1586
+ default:
1587
+ throw std::runtime_error("dpct_memcpy: invalid direction value");
1588
+ }
1589
+ return event_list;
1590
+ }
1591
+
1592
+ /// memcpy 2D/3D matrix specified by pitched_data.
1593
+ static inline std::vector<sycl::event>
1594
+ dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
1595
+ pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
1596
+ memcpy_direction direction = automatic)
1597
+ {
1598
+ return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
1599
+ sycl::range<3>(to.get_pitch(), to.get_y(), 1),
1600
+ sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
1601
+ size, direction);
1602
+ }
1603
+
1604
+ /// memcpy 2D matrix with pitch.
1605
+ static inline std::vector<sycl::event>
1606
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
1607
+ size_t to_pitch, size_t from_pitch, size_t x, size_t y,
1608
+ memcpy_direction direction = automatic)
1609
+ {
1610
+ return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
1611
+ sycl::range<3>(from_pitch, y, 1),
1612
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
1613
+ sycl::range<3>(x, y, 1), direction);
1614
+ }
1615
+
1616
+ namespace deprecated
1617
+ {
1618
+
1619
+ template <typename T, sycl::usm::alloc AllocKind>
1620
+ class usm_allocator
1621
+ {
1622
+ private:
1623
+ using Alloc = sycl::usm_allocator<T, AllocKind>;
1624
+ Alloc _impl;
1625
+
1626
+ public:
1627
+ using value_type = typename std::allocator_traits<Alloc>::value_type;
1628
+ using pointer = typename std::allocator_traits<Alloc>::pointer;
1629
+ using const_pointer = typename std::allocator_traits<Alloc>::const_pointer;
1630
+ using void_pointer = typename std::allocator_traits<Alloc>::void_pointer;
1631
+ using const_void_pointer =
1632
+ typename std::allocator_traits<Alloc>::const_void_pointer;
1633
+ using reference = typename std::allocator_traits<Alloc>::value_type &;
1634
+ using const_reference =
1635
+ const typename std::allocator_traits<Alloc>::value_type &;
1636
+ using difference_type =
1637
+ typename std::allocator_traits<Alloc>::difference_type;
1638
+ using size_type = typename std::allocator_traits<Alloc>::size_type;
1639
+ using propagate_on_container_copy_assignment = typename std::allocator_traits<
1640
+ Alloc>::propagate_on_container_copy_assignment;
1641
+ using propagate_on_container_move_assignment = typename std::allocator_traits<
1642
+ Alloc>::propagate_on_container_move_assignment;
1643
+ using propagate_on_container_swap =
1644
+ typename std::allocator_traits<Alloc>::propagate_on_container_swap;
1645
+ using is_always_equal =
1646
+ typename std::allocator_traits<Alloc>::is_always_equal;
1647
+
1648
+ template <typename U>
1649
+ struct rebind
1650
+ {
1651
+ typedef usm_allocator<U, AllocKind> other;
1652
+ };
1653
+
1654
+ usm_allocator() : _impl(dpct::get_default_queue()) {}
1655
+ ~usm_allocator() {}
1656
+ usm_allocator(const usm_allocator &other) : _impl(other._impl) {}
1657
+ usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {}
1658
+ pointer address(reference r) { return &r; }
1659
+ const_pointer address(const_reference r) { return &r; }
1660
+ pointer allocate(size_type cnt, const_void_pointer hint = nullptr)
1661
+ {
1662
+ return std::allocator_traits<Alloc>::allocate(_impl, cnt, hint);
1663
+ }
1664
+ void deallocate(pointer p, size_type cnt)
1665
+ {
1666
+ std::allocator_traits<Alloc>::deallocate(_impl, p, cnt);
1667
+ }
1668
+ size_type max_size() const
1669
+ {
1670
+ return std::allocator_traits<Alloc>::max_size(_impl);
1671
+ }
1672
+ bool operator==(const usm_allocator &other) const { return _impl == other._impl; }
1673
+ bool operator!=(const usm_allocator &other) const { return _impl != other._impl; }
1674
+ };
1675
+
1676
+ } // namespace deprecated
1677
+
1678
+ inline void dpct_free(void *ptr,
1679
+ const sycl::queue &q)
1680
+ {
1681
+ if (ptr)
1682
+ {
1683
+ sycl::free(ptr, q.get_context());
1684
+ }
1685
+ }
1686
+
1687
+ template <typename T>
1688
+ inline auto get_memory(const void *x)
1689
+ {
1690
+ T *new_x = reinterpret_cast<T *>(const_cast<void *>(x));
1691
+ return new_x;
1692
+ }
1693
+
1694
+ template <typename T>
1695
+ inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q)
1696
+ {
1697
+ using Ty = typename DataType<T>::T2;
1698
+ Ty s_h;
1699
+ if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only)
1700
+ detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host)
1701
+ .wait();
1702
+ else
1703
+ s_h = *reinterpret_cast<const Ty *>(s);
1704
+ return s_h;
1705
+ }
1706
+
1707
+ } // namespace detail
1708
+
1709
+ template <typename T>
1710
+ inline auto get_value(const T *s, sycl::queue &q)
1711
+ {
1712
+ return detail::get_value(s, q);
1713
+ }
1714
+
1715
+ namespace detail
1716
+ {
1717
+ template <class Ta, class Tb, class Tc, class Ts>
1718
+ inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
1719
+ int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
1720
+ const void * beta, void * c, int ldc) {
1721
+ Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1722
+ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1723
+ auto data_a = get_memory<const Ta>(a);
1724
+ auto data_b = get_memory<const Tb>(b);
1725
+ auto data_c = get_memory<Tc>(c);
1726
+ oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
1727
+ lda, data_b, ldb, beta_value, data_c, ldc);
1728
+ }
1729
+
1730
+ template <typename VecT, class BinaryOperation, class = void>
1731
+ class vectorized_binary
1732
+ {
1733
+ public:
1734
+ inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
1735
+ {
1736
+ VecT v4;
1737
+ for (size_t i = 0; i < v4.size(); ++i)
1738
+ {
1739
+ v4[i] = binary_op(a[i], b[i]);
1740
+ }
1741
+ return v4;
1742
+ }
1743
+ };
1744
+
1745
+ template <typename VecT, class BinaryOperation>
1746
+ class vectorized_binary<
1747
+ VecT, BinaryOperation,
1748
+ std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>>
1749
+ {
1750
+ public:
1751
+ inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
1752
+ {
1753
+ return binary_op(a, b).template as<VecT>();
1754
+ }
1755
+ };
1756
+
1757
+ template <class Ta, class Tb, class Tc, class Ts>
1758
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1759
+ int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
1760
+ int ldb, const void * beta, void ** c, int ldc, int batch_size,
1761
+ matrix_info_t<float> * matrix_info) {
1762
+ Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1763
+ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1764
+
1765
+ matrix_info->transpose_info[0] = a_trans;
1766
+ matrix_info->transpose_info[1] = b_trans;
1767
+ matrix_info->value_info[0] = alpha_value;
1768
+ matrix_info->value_info[1] = beta_value;
1769
+ matrix_info->size_info[0] = m;
1770
+ matrix_info->size_info[1] = n;
1771
+ matrix_info->size_info[2] = k;
1772
+ matrix_info->ld_info[0] = lda;
1773
+ matrix_info->ld_info[1] = ldb;
1774
+ matrix_info->ld_info[2] = ldc;
1775
+ matrix_info->groupsize_info = batch_size;
1776
+
1777
+ sycl::event e = oneapi::math::blas::column_major::gemm_batch(
1778
+ get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
1779
+ matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
1780
+ reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
1781
+ reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
1782
+ reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
1783
+ matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
1784
+ }
1785
+
1786
+ template <class Ta, class Tb, class Tc, class Ts>
1787
+ inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
1788
+ int m, int n, int k, const void * alpha, const void * a, int lda,
1789
+ long long int stride_a, const void * b, int ldb, long long int stride_b,
1790
+ const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
1791
+ Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
1792
+ Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
1793
+ auto data_a = get_memory<const Ta>(a);
1794
+ auto data_b = get_memory<const Tb>(b);
1795
+ auto data_c = get_memory<Tc>(c);
1796
+ oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
1797
+ data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
1798
+ data_c, ldc, stride_c, batch_size);
1799
+ }
1800
+
1801
+ } // namespace detail
1802
+
1803
+ template <typename VecT, class BinaryOperation>
1804
+ inline unsigned vectorized_binary(unsigned a, unsigned b,
1805
+ const BinaryOperation binary_op)
1806
+ {
1807
+ sycl::vec<unsigned, 1> v0{a}, v1{b};
1808
+ auto v2 = v0.as<VecT>();
1809
+ auto v3 = v1.as<VecT>();
1810
+ auto v4 =
1811
+ detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
1812
+ v0 = v4.template as<sycl::vec<unsigned, 1>>();
1813
+ return v0;
1814
+ }
1815
+
1816
+ static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size,
1817
+ memcpy_direction direction = automatic,
1818
+ sycl::queue &q = dpct::get_default_queue())
1819
+ {
1820
+ detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction);
1821
+ }
1822
+
1823
+ static inline unsigned int select_device(unsigned int id)
1824
+ {
1825
+ dev_mgr::instance().select_device(id);
1826
+ return id;
1827
+ }
1828
+
1829
+ template <typename T>
1830
+ T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
1831
+ unsigned int logical_sub_group_size = 32)
1832
+ {
1833
+ unsigned int id = g.get_local_linear_id();
1834
+ unsigned int start_index =
1835
+ id / logical_sub_group_size * logical_sub_group_size;
1836
+ unsigned int target_offset = (id % logical_sub_group_size) ^ mask;
1837
+ return sycl::select_from_group(g, x,
1838
+ target_offset < logical_sub_group_size
1839
+ ? start_index + target_offset
1840
+ : id);
1841
+ }
1842
+
1843
+ template <typename T1, typename T2, typename T3>
1844
+ inline auto dp4a(T1 a, T2 b, T3 c)
1845
+ {
1846
+ return syclcompat::dp4a(a, b, c);
1847
+ }
1848
+
1849
+ struct sub_sat
1850
+ {
1851
+ template <typename T>
1852
+ auto operator()(const T x, const T y) const
1853
+ {
1854
+ return sycl::sub_sat(x, y);
1855
+ }
1856
+ };
1857
+
1858
+ template <typename S, typename T>
1859
+ inline T vectorized_min(T a, T b)
1860
+ {
1861
+ sycl::vec<T, 1> v0{a}, v1{b};
1862
+ auto v2 = v0.template as<S>();
1863
+ auto v3 = v1.template as<S>();
1864
+ auto v4 = sycl::min(v2, v3);
1865
+ v0 = v4.template as<sycl::vec<T, 1>>();
1866
+ return v0;
1867
+ }
1868
+
1869
+ inline float pow(const float a, const int b) { return sycl::pown(a, b); }
1870
+ inline double pow(const double a, const int b) { return sycl::pown(a, b); }
1871
+ inline float pow(const float a, const float b) { return sycl::pow(a, b); }
1872
+ inline double pow(const double a, const double b) { return sycl::pow(a, b); }
1873
+ template <typename T, typename U>
1874
+ inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
1875
+ pow(const T a, const U b)
1876
+ {
1877
+ return sycl::pow(a, static_cast<T>(b));
1878
+ }
1879
+ template <typename T, typename U>
1880
+ inline typename std::enable_if_t<!std::is_floating_point_v<T>, double>
1881
+ pow(const T a, const U b)
1882
+ {
1883
+ return sycl::pow(static_cast<double>(a), static_cast<double>(b));
1884
+ }
1885
+
1886
+ inline double min(const double a, const float b)
1887
+ {
1888
+ return sycl::fmin(a, static_cast<double>(b));
1889
+ }
1890
+ inline double min(const float a, const double b)
1891
+ {
1892
+ return sycl::fmin(static_cast<double>(a), b);
1893
+ }
1894
+ inline float min(const float a, const float b) { return sycl::fmin(a, b); }
1895
+ inline double min(const double a, const double b) { return sycl::fmin(a, b); }
1896
+ inline std::uint32_t min(const std::uint32_t a, const std::int32_t b)
1897
+ {
1898
+ return sycl::min(a, static_cast<std::uint32_t>(b));
1899
+ }
1900
+ inline std::uint32_t min(const std::int32_t a, const std::uint32_t b)
1901
+ {
1902
+ return sycl::min(static_cast<std::uint32_t>(a), b);
1903
+ }
1904
+ inline std::int32_t min(const std::int32_t a, const std::int32_t b)
1905
+ {
1906
+ return sycl::min(a, b);
1907
+ }
1908
+ inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b)
1909
+ {
1910
+ return sycl::min(a, b);
1911
+ }
1912
+ inline std::uint64_t min(const std::uint64_t a, const std::int64_t b)
1913
+ {
1914
+ return sycl::min(a, static_cast<std::uint64_t>(b));
1915
+ }
1916
+ inline std::uint64_t min(const std::int64_t a, const std::uint64_t b)
1917
+ {
1918
+ return sycl::min(static_cast<std::uint64_t>(a), b);
1919
+ }
1920
+ inline std::int64_t min(const std::int64_t a, const std::int64_t b)
1921
+ {
1922
+ return sycl::min(a, b);
1923
+ }
1924
+ inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b)
1925
+ {
1926
+ return sycl::min(a, b);
1927
+ }
1928
+ inline std::uint64_t min(const std::uint64_t a, const std::int32_t b)
1929
+ {
1930
+ return sycl::min(a, static_cast<std::uint64_t>(b));
1931
+ }
1932
+ inline std::uint64_t min(const std::int32_t a, const std::uint64_t b)
1933
+ {
1934
+ return sycl::min(static_cast<std::uint64_t>(a), b);
1935
+ }
1936
+ inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b)
1937
+ {
1938
+ return sycl::min(a, static_cast<std::uint64_t>(b));
1939
+ }
1940
+ inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b)
1941
+ {
1942
+ return sycl::min(static_cast<std::uint64_t>(a), b);
1943
+ }
1944
+ // max function overloads.
1945
+ // For floating-point types, `float` or `double` arguments are acceptable.
1946
+ // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
1947
+ // `std::int64_t` type arguments are acceptable.
1948
+ inline double max(const double a, const float b)
1949
+ {
1950
+ return sycl::fmax(a, static_cast<double>(b));
1951
+ }
1952
+ inline double max(const float a, const double b)
1953
+ {
1954
+ return sycl::fmax(static_cast<double>(a), b);
1955
+ }
1956
+ inline float max(const float a, const float b) { return sycl::fmax(a, b); }
1957
+ inline double max(const double a, const double b) { return sycl::fmax(a, b); }
1958
+ inline std::uint32_t max(const std::uint32_t a, const std::int32_t b)
1959
+ {
1960
+ return sycl::max(a, static_cast<std::uint32_t>(b));
1961
+ }
1962
+ inline std::uint32_t max(const std::int32_t a, const std::uint32_t b)
1963
+ {
1964
+ return sycl::max(static_cast<std::uint32_t>(a), b);
1965
+ }
1966
+ inline std::int32_t max(const std::int32_t a, const std::int32_t b)
1967
+ {
1968
+ return sycl::max(a, b);
1969
+ }
1970
+ inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b)
1971
+ {
1972
+ return sycl::max(a, b);
1973
+ }
1974
+ inline std::uint64_t max(const std::uint64_t a, const std::int64_t b)
1975
+ {
1976
+ return sycl::max(a, static_cast<std::uint64_t>(b));
1977
+ }
1978
+ inline std::uint64_t max(const std::int64_t a, const std::uint64_t b)
1979
+ {
1980
+ return sycl::max(static_cast<std::uint64_t>(a), b);
1981
+ }
1982
+ inline std::int64_t max(const std::int64_t a, const std::int64_t b)
1983
+ {
1984
+ return sycl::max(a, b);
1985
+ }
1986
+ inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b)
1987
+ {
1988
+ return sycl::max(a, b);
1989
+ }
1990
+ inline std::uint64_t max(const std::uint64_t a, const std::int32_t b)
1991
+ {
1992
+ return sycl::max(a, static_cast<std::uint64_t>(b));
1993
+ }
1994
+ inline std::uint64_t max(const std::int32_t a, const std::uint64_t b)
1995
+ {
1996
+ return sycl::max(static_cast<std::uint64_t>(a), b);
1997
+ }
1998
+ inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b)
1999
+ {
2000
+ return sycl::max(a, static_cast<std::uint64_t>(b));
2001
+ }
2002
+ inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b)
2003
+ {
2004
+ return sycl::max(static_cast<std::uint64_t>(a), b);
2005
+ }
2006
+
2007
+ inline void
2008
+ has_capability_or_fail(const sycl::device &dev,
2009
+ const std::initializer_list<sycl::aspect> &props)
2010
+ {
2011
+ for (const auto &it : props)
2012
+ {
2013
+ if (dev.has(it))
2014
+ continue;
2015
+ switch (it)
2016
+ {
2017
+ case sycl::aspect::fp64:
2018
+ throw std::runtime_error("'double' is not supported in '" +
2019
+ dev.get_info<sycl::info::device::name>() +
2020
+ "' device");
2021
+ break;
2022
+ case sycl::aspect::fp16:
2023
+ throw std::runtime_error("'half' is not supported in '" +
2024
+ dev.get_info<sycl::info::device::name>() +
2025
+ "' device");
2026
+ break;
2027
+ default:
2028
+ #define __SYCL_ASPECT(ASPECT, ID) \
2029
+ case sycl::aspect::ASPECT: \
2030
+ return #ASPECT;
2031
+ #define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
2032
+ #define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
2033
+ auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string
2034
+ {
2035
+ switch (AspectNum)
2036
+ {
2037
+ #include <sycl/info/aspects.def>
2038
+ #include <sycl/info/aspects_deprecated.def>
2039
+ default:
2040
+ return "unknown aspect";
2041
+ }
2042
+ };
2043
+ #undef __SYCL_ASPECT_DEPRECATED_ALIAS
2044
+ #undef __SYCL_ASPECT_DEPRECATED
2045
+ #undef __SYCL_ASPECT
2046
+ throw std::runtime_error(
2047
+ "'" + getAspectNameStr(it) + "' is not supported in '" +
2048
+ dev.get_info<sycl::info::device::name>() + "' device");
2049
+ }
2050
+ break;
2051
+ }
2052
+ }
2053
+
2054
+ static inline unsigned int get_current_device_id()
2055
+ {
2056
+ return dev_mgr::instance().current_device_id();
2057
+ }
2058
+
2059
+ static inline device_ext &get_current_device()
2060
+ {
2061
+ return dev_mgr::instance().current_device();
2062
+ }
2063
+
2064
+ static inline device_ext &get_device(unsigned int id)
2065
+ {
2066
+ return dev_mgr::instance().get_device(id);
2067
+ }
2068
+
2069
+ static inline sycl::queue &get_in_order_queue()
2070
+ {
2071
+ return dev_mgr::instance().current_device().in_order_queue();
2072
+ }
2073
+
2074
+ static sycl::event
2075
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
2076
+ memcpy_direction direction,
2077
+ const std::vector<sycl::event> &dep_events = {})
2078
+ {
2079
+ if (!size)
2080
+ return sycl::event{};
2081
+ return q.memcpy(to_ptr, from_ptr, size, dep_events);
2082
+ GGML_UNUSED(direction);
2083
+ }
2084
+
2085
+ // Get actual copy range and make sure it will not exceed range.
2086
+ static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
2087
+ size_t pitch)
2088
+ {
2089
+ return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
2090
+ }
2091
+
2092
+ static inline size_t get_offset(sycl::id<3> id, size_t slice,
2093
+ size_t pitch)
2094
+ {
2095
+ return slice * id.get(2) + pitch * id.get(1) + id.get(0);
2096
+ }
2097
+
2098
+ /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
2099
+ /// and \p from_range to another specified by \p to_ptr and \p to_range.
2100
+ static inline std::vector<sycl::event>
2101
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
2102
+ sycl::range<3> to_range, sycl::range<3> from_range,
2103
+ sycl::id<3> to_id, sycl::id<3> from_id,
2104
+ sycl::range<3> size, memcpy_direction direction,
2105
+ const std::vector<sycl::event> &dep_events = {})
2106
+ {
2107
+ // RAII for host pointer
2108
+ class host_buffer
2109
+ {
2110
+ void *_buf;
2111
+ size_t _size;
2112
+ sycl::queue &_q;
2113
+ const std::vector<sycl::event> &_deps; // free operation depends
2114
+
2115
+ public:
2116
+ host_buffer(size_t size, sycl::queue &q,
2117
+ const std::vector<sycl::event> &deps)
2118
+ : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
2119
+ void *get_ptr() const { return _buf; }
2120
+ size_t get_size() const { return _size; }
2121
+ ~host_buffer()
2122
+ {
2123
+ if (_buf)
2124
+ {
2125
+ _q.submit([&](sycl::handler &cgh)
2126
+ {
2127
+ cgh.depends_on(_deps);
2128
+ cgh.host_task([buf = _buf] { std::free(buf); }); });
2129
+ }
2130
+ }
2131
+ };
2132
+ std::vector<sycl::event> event_list;
2133
+
2134
+ size_t to_slice = to_range.get(1) * to_range.get(0),
2135
+ from_slice = from_range.get(1) * from_range.get(0);
2136
+ unsigned char *to_surface =
2137
+ (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
2138
+ const unsigned char *from_surface =
2139
+ (const unsigned char *)from_ptr +
2140
+ get_offset(from_id, from_slice, from_range.get(0));
2141
+
2142
+ if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
2143
+ {
2144
+ return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
2145
+ direction, dep_events)};
2146
+ }
2147
+ direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
2148
+ size_t size_slice = size.get(1) * size.get(0);
2149
+ switch (direction)
2150
+ {
2151
+ case host_to_host:
2152
+ for (size_t z = 0; z < size.get(2); ++z)
2153
+ {
2154
+ unsigned char *to_ptr = to_surface;
2155
+ const unsigned char *from_ptr = from_surface;
2156
+ if (to_range.get(0) == from_range.get(0) &&
2157
+ to_range.get(0) == size.get(0))
2158
+ {
2159
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
2160
+ direction, dep_events));
2161
+ }
2162
+ else
2163
+ {
2164
+ for (size_t y = 0; y < size.get(1); ++y)
2165
+ {
2166
+ event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
2167
+ direction, dep_events));
2168
+ to_ptr += to_range.get(0);
2169
+ from_ptr += from_range.get(0);
2170
+ }
2171
+ }
2172
+ to_surface += to_slice;
2173
+ from_surface += from_slice;
2174
+ }
2175
+ break;
2176
+ case host_to_device:
2177
+ {
2178
+ host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
2179
+ event_list);
2180
+ std::vector<sycl::event> host_events;
2181
+ if (to_slice == size_slice)
2182
+ {
2183
+ // Copy host data to a temp host buffer with the shape of target.
2184
+ host_events =
2185
+ dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
2186
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
2187
+ host_to_host, dep_events);
2188
+ }
2189
+ else
2190
+ {
2191
+ // Copy host data to a temp host buffer with the shape of target.
2192
+ host_events = dpct_memcpy(
2193
+ q, buf.get_ptr(), from_surface, to_range, from_range,
2194
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
2195
+ // If has padding data, not sure whether it is useless. So fill temp
2196
+ // buffer with it.
2197
+ std::vector<sycl::event>{
2198
+ dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
2199
+ device_to_host, dep_events)});
2200
+ }
2201
+ // Copy from temp host buffer to device with only one submit.
2202
+ event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
2203
+ buf.get_size(), host_to_device,
2204
+ host_events));
2205
+ break;
2206
+ }
2207
+ case device_to_host:
2208
+ {
2209
+ host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
2210
+ event_list);
2211
+ // Copy from host temp buffer to host target with reshaping.
2212
+ event_list = dpct_memcpy(
2213
+ q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
2214
+ sycl::id<3>(0, 0, 0), size, host_to_host,
2215
+ // Copy from device to temp host buffer with only one submit.
2216
+ std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
2217
+ buf.get_size(),
2218
+ device_to_host, dep_events)});
2219
+ break;
2220
+ }
2221
+ case device_to_device:
2222
+ event_list.push_back(q.submit([&](sycl::handler &cgh)
2223
+ {
2224
+ cgh.depends_on(dep_events);
2225
+ cgh.parallel_for<class dpct_memcpy_3d_detail>(
2226
+ size,
2227
+ [=](sycl::id<3> id) {
2228
+ to_surface[get_offset(id, to_slice, to_range.get(0))] =
2229
+ from_surface[get_offset(id, from_slice, from_range.get(0))];
2230
+ }); }));
2231
+ break;
2232
+ default:
2233
+ throw std::runtime_error("dpct_memcpy: invalid direction value");
2234
+ }
2235
+ return event_list;
2236
+ }
2237
+
2238
+ /// memcpy 2D/3D matrix specified by pitched_data.
2239
+ static inline std::vector<sycl::event>
2240
+ dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
2241
+ pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
2242
+ memcpy_direction direction = automatic)
2243
+ {
2244
+ return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
2245
+ sycl::range<3>(to.get_pitch(), to.get_y(), 1),
2246
+ sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id,
2247
+ size, direction);
2248
+ }
2249
+
2250
+ /// memcpy 2D matrix with pitch.
2251
+ static inline std::vector<sycl::event>
2252
+ dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
2253
+ size_t to_pitch, size_t from_pitch, size_t x, size_t y,
2254
+ memcpy_direction direction = automatic)
2255
+ {
2256
+ return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
2257
+ sycl::range<3>(from_pitch, y, 1),
2258
+ sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
2259
+ sycl::range<3>(x, y, 1), direction);
2260
+ }
2261
+
2262
+ inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
2263
+ int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
2264
+ library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
2265
+ library_data_t scaling_type) {
2266
+ if (scaling_type == library_data_t::real_float &&
2267
+ c_type == library_data_t::complex_float)
2268
+ {
2269
+ scaling_type = library_data_t::complex_float;
2270
+ }
2271
+ else if (scaling_type == library_data_t::real_double &&
2272
+ c_type == library_data_t::complex_double)
2273
+ {
2274
+ scaling_type = library_data_t::complex_double;
2275
+ }
2276
+
2277
+ std::uint64_t key =
2278
+ detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2279
+ switch (key)
2280
+ {
2281
+ case detail::get_type_combination_id(
2282
+ library_data_t::real_float, library_data_t::real_float,
2283
+ library_data_t::real_float, library_data_t::real_float):
2284
+ {
2285
+ detail::gemm_impl<float, float, float, float>(
2286
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2287
+ break;
2288
+ }
2289
+ case detail::get_type_combination_id(
2290
+ library_data_t::real_double, library_data_t::real_double,
2291
+ library_data_t::real_double, library_data_t::real_double):
2292
+ {
2293
+ detail::gemm_impl<double, double, double, double>(
2294
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2295
+ break;
2296
+ }
2297
+ case detail::get_type_combination_id(
2298
+ library_data_t::complex_float, library_data_t::complex_float,
2299
+ library_data_t::complex_float, library_data_t::complex_float):
2300
+ {
2301
+ detail::gemm_impl<std::complex<float>, std::complex<float>,
2302
+ std::complex<float>, std::complex<float>>(
2303
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2304
+ break;
2305
+ }
2306
+ case detail::get_type_combination_id(
2307
+ library_data_t::complex_double, library_data_t::complex_double,
2308
+ library_data_t::complex_double, library_data_t::complex_double):
2309
+ {
2310
+ detail::gemm_impl<std::complex<double>, std::complex<double>,
2311
+ std::complex<double>, std::complex<double>>(
2312
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2313
+ break;
2314
+ }
2315
+ case detail::get_type_combination_id(
2316
+ library_data_t::real_half, library_data_t::real_half,
2317
+ library_data_t::real_half, library_data_t::real_half):
2318
+ {
2319
+ detail::gemm_impl<sycl::half, sycl::half, sycl::half,
2320
+ sycl::half>(q, a_trans, b_trans, m, n, k, alpha, a,
2321
+ lda, b, ldb, beta, c, ldc);
2322
+ break;
2323
+ }
2324
+ #ifdef __INTEL_MKL__
2325
+ case detail::get_type_combination_id(
2326
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2327
+ library_data_t::real_float, library_data_t::real_float):
2328
+ {
2329
+ detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2330
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2331
+ break;
2332
+ }
2333
+ case detail::get_type_combination_id(
2334
+ library_data_t::real_half, library_data_t::real_half,
2335
+ library_data_t::real_float, library_data_t::real_float):
2336
+ {
2337
+ detail::gemm_impl<sycl::half, sycl::half, float, float>(
2338
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2339
+ break;
2340
+ }
2341
+ case detail::get_type_combination_id(
2342
+ library_data_t::real_half, library_data_t::real_half,
2343
+ library_data_t::real_half, library_data_t::real_float):
2344
+ {
2345
+ float alpha_value =
2346
+ dpct::get_value(reinterpret_cast<const float *>(alpha), q);
2347
+ float beta_value =
2348
+ dpct::get_value(reinterpret_cast<const float *>(beta), q);
2349
+ sycl::half alpha_half(alpha_value);
2350
+ sycl::half beta_half(beta_value);
2351
+ detail::gemm_impl<sycl::half, sycl::half, sycl::half,
2352
+ sycl::half>(q, a_trans, b_trans, m, n, k, &alpha_half,
2353
+ a, lda, b, ldb, &beta_half, c, ldc);
2354
+ break;
2355
+ }
2356
+ case detail::get_type_combination_id(
2357
+ library_data_t::real_int8, library_data_t::real_int8,
2358
+ library_data_t::real_float, library_data_t::real_float):
2359
+ {
2360
+ detail::gemm_impl<std::int8_t, std::int8_t, float, float>(
2361
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2362
+ break;
2363
+ }
2364
+ case detail::get_type_combination_id(
2365
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2366
+ library_data_t::real_bfloat16, library_data_t::real_float):
2367
+ {
2368
+ detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2369
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
2370
+ break;
2371
+ }
2372
+ case detail::get_type_combination_id(
2373
+ library_data_t::real_int8, library_data_t::real_int8,
2374
+ library_data_t::real_int32, library_data_t::real_int32):
2375
+ {
2376
+ float alpha_float =
2377
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
2378
+ float beta_float =
2379
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2380
+ detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2381
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
2382
+ break;
2383
+ }
2384
+ #endif // __INTEL_MKL__
2385
+ default:
2386
+ throw std::runtime_error("the combination of data type is unsupported");
2387
+ }
2388
+ } // gemm()
2389
+
2390
+ /// Computes a batch of matrix-matrix product with general matrices.
2391
+ /// \param [in] q The queue where the routine should be executed.
2392
+ /// \param [in] a_trans Specifies the operation applied to A.
2393
+ /// \param [in] b_trans Specifies the operation applied to B.
2394
+ /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
2395
+ /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
2396
+ /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
2397
+ /// \param [in] alpha Scaling factor for the matrix-matrix product.
2398
+ /// \param [in] a Input matrix A.
2399
+ /// \param [in] a_type Data type of the matrix A.
2400
+ /// \param [in] lda Leading dimension of A.
2401
+ /// \param [in] b Input matrix B.
2402
+ /// \param [in] b_type Data type of the matrix B.
2403
+ /// \param [in] ldb Leading dimension of B.
2404
+ /// \param [in] beta Scaling factor for matrix C.
2405
+ /// \param [in, out] c Input/Output matrix C.
2406
+ /// \param [in] c_type Data type of the matrix C.
2407
+ /// \param [in] ldc Leading dimension of C.
2408
+ /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2409
+ /// \param [in] scaling_type Data type of the scaling factors.
2410
+ inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2411
+ int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
2412
+ const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
2413
+ library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
2414
+ matrix_info_t<float> * matrix_info) {
2415
+ std::uint64_t key =
2416
+ detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2417
+ switch (key)
2418
+ {
2419
+ case detail::get_type_combination_id(
2420
+ library_data_t::real_float, library_data_t::real_float,
2421
+ library_data_t::real_float, library_data_t::real_float):
2422
+ {
2423
+ detail::gemm_batch_impl<float, float, float, float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2424
+ beta, c, ldc, batch_size, matrix_info);
2425
+ break;
2426
+ }
2427
+ case detail::get_type_combination_id(
2428
+ library_data_t::real_double, library_data_t::real_double,
2429
+ library_data_t::real_double, library_data_t::real_double):
2430
+ {
2431
+ detail::gemm_batch_impl<double, double, double, double>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb,
2432
+ beta, c, ldc, batch_size, matrix_info);
2433
+ break;
2434
+ }
2435
+ case detail::get_type_combination_id(
2436
+ library_data_t::real_half, library_data_t::real_half,
2437
+ library_data_t::real_half, library_data_t::real_half):
2438
+ {
2439
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2440
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2441
+ break;
2442
+ }
2443
+ #ifdef __INTEL_MKL__
2444
+ case detail::get_type_combination_id(
2445
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2446
+ library_data_t::real_bfloat16, library_data_t::real_float):
2447
+ {
2448
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2449
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2450
+ break;
2451
+ }
2452
+ case detail::get_type_combination_id(
2453
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2454
+ library_data_t::real_float, library_data_t::real_float):
2455
+ {
2456
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2457
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2458
+ break;
2459
+ }
2460
+ #endif
2461
+ case detail::get_type_combination_id(
2462
+ library_data_t::real_int8, library_data_t::real_int8,
2463
+ library_data_t::real_int32, library_data_t::real_int32):
2464
+ {
2465
+ float alpha_float =
2466
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
2467
+ float beta_float =
2468
+ dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
2469
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t, float>(
2470
+ q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size,
2471
+ matrix_info);
2472
+ break;
2473
+ }
2474
+ case detail::get_type_combination_id(
2475
+ library_data_t::real_int8, library_data_t::real_int8,
2476
+ library_data_t::real_float, library_data_t::real_float):
2477
+ {
2478
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2479
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2480
+ break;
2481
+ }
2482
+ case detail::get_type_combination_id(
2483
+ library_data_t::real_half, library_data_t::real_half,
2484
+ library_data_t::real_float, library_data_t::real_float):
2485
+ {
2486
+ detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2487
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
2488
+ break;
2489
+ }
2490
+ case detail::get_type_combination_id(
2491
+ library_data_t::real_half, library_data_t::real_half,
2492
+ library_data_t::real_half, library_data_t::real_float):
2493
+ {
2494
+ float alpha_value =
2495
+ dpct::get_value(reinterpret_cast<const float *>(alpha), q);
2496
+ float beta_value =
2497
+ dpct::get_value(reinterpret_cast<const float *>(beta), q);
2498
+ sycl::half alpha_half(alpha_value);
2499
+ sycl::half beta_half(beta_value);
2500
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2501
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info);
2502
+ break;
2503
+ }
2504
+ default:
2505
+ throw std::runtime_error("the combination of data type is unsupported");
2506
+ }
2507
+ }
2508
+
2509
+ /// Computes a batch of matrix-matrix product with general matrices.
2510
+ /// \param [in] q The queue where the routine should be executed.
2511
+ /// \param [in] a_trans Specifies the operation applied to A.
2512
+ /// \param [in] b_trans Specifies the operation applied to B.
2513
+ /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
2514
+ /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
2515
+ /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
2516
+ /// \param [in] alpha Scaling factor for the matrix-matrix product.
2517
+ /// \param [in] a Input matrix A.
2518
+ /// \param [in] a_type Data type of the matrix A.
2519
+ /// \param [in] lda Leading dimension of A.
2520
+ /// \param [in] stride_a Stride between the different A matrices.
2521
+ /// \param [in] b Input matrix B.
2522
+ /// \param [in] b_type Data type of the matrix B.
2523
+ /// \param [in] ldb Leading dimension of B.
2524
+ /// \param [in] stride_b Stride between the different B matrices.
2525
+ /// \param [in] beta Scaling factor for matrix C.
2526
+ /// \param [in, out] c Input/Output matrix C.
2527
+ /// \param [in] c_type Data type of the matrix C.
2528
+ /// \param [in] ldc Leading dimension of C.
2529
+ /// \param [in] stride_c Stride between the different C matrices.
2530
+ /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
2531
+ /// \param [in] scaling_type Data type of the scaling factors.
2532
+ inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
2533
+ int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
2534
+ long long int stride_a, const void * b, library_data_t b_type, int ldb,
2535
+ long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
2536
+ long long int stride_c, int batch_size, library_data_t scaling_type) {
2537
+ if (scaling_type == library_data_t::real_float &&
2538
+ c_type == library_data_t::complex_float)
2539
+ {
2540
+ scaling_type = library_data_t::complex_float;
2541
+ }
2542
+ else if (scaling_type == library_data_t::real_double &&
2543
+ c_type == library_data_t::complex_double)
2544
+ {
2545
+ scaling_type = library_data_t::complex_double;
2546
+ }
2547
+
2548
+ std::uint64_t key =
2549
+ detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
2550
+ switch (key)
2551
+ {
2552
+ case detail::get_type_combination_id(
2553
+ library_data_t::real_float, library_data_t::real_float,
2554
+ library_data_t::real_float, library_data_t::real_float):
2555
+ {
2556
+ detail::gemm_batch_impl<float, float, float, float>(
2557
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2558
+ beta, c, ldc, stride_c, batch_size);
2559
+ break;
2560
+ }
2561
+ case detail::get_type_combination_id(
2562
+ library_data_t::real_double, library_data_t::real_double,
2563
+ library_data_t::real_double, library_data_t::real_double):
2564
+ {
2565
+ detail::gemm_batch_impl<double, double, double, double>(
2566
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2567
+ beta, c, ldc, stride_c, batch_size);
2568
+ break;
2569
+ }
2570
+ case detail::get_type_combination_id(
2571
+ library_data_t::complex_float, library_data_t::complex_float,
2572
+ library_data_t::complex_float, library_data_t::complex_float):
2573
+ {
2574
+ detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
2575
+ std::complex<float>, std::complex<float>>(
2576
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2577
+ beta, c, ldc, stride_c, batch_size);
2578
+ break;
2579
+ }
2580
+ case detail::get_type_combination_id(
2581
+ library_data_t::complex_double, library_data_t::complex_double,
2582
+ library_data_t::complex_double, library_data_t::complex_double):
2583
+ {
2584
+ detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
2585
+ std::complex<double>, std::complex<double>>(
2586
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2587
+ beta, c, ldc, stride_c, batch_size);
2588
+ break;
2589
+ }
2590
+ case detail::get_type_combination_id(
2591
+ library_data_t::real_half, library_data_t::real_half,
2592
+ library_data_t::real_half, library_data_t::real_half):
2593
+ {
2594
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
2595
+ sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
2596
+ a, lda, stride_a, b, ldb, stride_b,
2597
+ beta, c, ldc, stride_c, batch_size);
2598
+ break;
2599
+ }
2600
+ #ifdef __INTEL_MKL__
2601
+ case detail::get_type_combination_id(
2602
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2603
+ library_data_t::real_bfloat16, library_data_t::real_float):
2604
+ {
2605
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
2606
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2607
+ batch_size);
2608
+ break;
2609
+ }
2610
+ case detail::get_type_combination_id(
2611
+ library_data_t::real_bfloat16, library_data_t::real_bfloat16,
2612
+ library_data_t::real_float, library_data_t::real_float):
2613
+ {
2614
+ detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
2615
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
2616
+ batch_size);
2617
+ break;
2618
+ }
2619
+ #endif
2620
+ case detail::get_type_combination_id(
2621
+ library_data_t::real_int8, library_data_t::real_int8,
2622
+ library_data_t::real_int32, library_data_t::real_int32):
2623
+ {
2624
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
2625
+ std::int32_t>(q, a_trans, b_trans, m, n, k, alpha,
2626
+ a, lda, stride_a, b, ldb, stride_b,
2627
+ beta, c, ldc, stride_c, batch_size);
2628
+ break;
2629
+ }
2630
+ case detail::get_type_combination_id(
2631
+ library_data_t::real_int8, library_data_t::real_int8,
2632
+ library_data_t::real_float, library_data_t::real_float):
2633
+ {
2634
+ detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
2635
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2636
+ beta, c, ldc, stride_c, batch_size);
2637
+ break;
2638
+ }
2639
+ case detail::get_type_combination_id(
2640
+ library_data_t::real_half, library_data_t::real_half,
2641
+ library_data_t::real_float, library_data_t::real_float):
2642
+ {
2643
+ detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
2644
+ q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
2645
+ beta, c, ldc, stride_c, batch_size);
2646
+ break;
2647
+ }
2648
+ case detail::get_type_combination_id(
2649
+ library_data_t::real_half, library_data_t::real_half,
2650
+ library_data_t::real_half, library_data_t::real_float):
2651
+ {
2652
+ float alpha_value =
2653
+ dpct::get_value(reinterpret_cast<const float *>(alpha), q);
2654
+ float beta_value =
2655
+ dpct::get_value(reinterpret_cast<const float *>(beta), q);
2656
+ sycl::half alpha_half(alpha_value);
2657
+ sycl::half beta_half(beta_value);
2658
+ detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
2659
+ q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b,
2660
+ &beta_half, c, ldc, stride_c, batch_size);
2661
+ break;
2662
+ }
2663
+ default:
2664
+ throw std::runtime_error("the combination of data type is unsupported");
2665
+ }
2666
+ }
2667
+
2668
+ static inline void
2669
+ async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr,
2670
+ size_t from_pitch, size_t x, size_t y,
2671
+ memcpy_direction direction = automatic,
2672
+ sycl::queue &q = get_default_queue())
2673
+ {
2674
+ detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y,
2675
+ direction);
2676
+ }
2677
+
2678
+ using err0 = detail::generic_error_type<struct err0_tag, int>;
2679
+ using err1 = detail::generic_error_type<struct err1_tag, int>;
2680
+
2681
+ static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) {
2682
+ detail::dpct_free(ptr, q);
2683
+ }
2684
+
2685
+ /// dpct accessor used as device function parameter.
2686
+ template <class T, memory_region Memory, size_t Dimension> class accessor;
2687
+ template <class T, memory_region Memory> class accessor<T, Memory, 3> {
2688
+ public:
2689
+ using memory_t = detail::memory_traits<Memory, T>;
2690
+ using element_t = typename memory_t::element_t;
2691
+ using pointer_t = typename memory_t::pointer_t;
2692
+ using accessor_t = typename memory_t::template accessor_t<3>;
2693
+ accessor(pointer_t data, const sycl::range<3> &in_range)
2694
+ : _data(data), _range(in_range) {}
2695
+ template <memory_region M = Memory>
2696
+ accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
2697
+ : accessor(acc, acc.get_range()) {}
2698
+ accessor(const accessor_t &acc, const sycl::range<3> &in_range)
2699
+ : accessor(acc.get_pointer(), in_range) {}
2700
+ accessor<T, Memory, 2> operator[](size_t index) const {
2701
+ sycl::range<2> sub(_range.get(1), _range.get(2));
2702
+ return accessor<T, Memory, 2>(_data + index * sub.size(), sub);
2703
+ }
2704
+
2705
+ pointer_t get_ptr() const { return _data; }
2706
+
2707
+ private:
2708
+ pointer_t _data;
2709
+ sycl::range<3> _range;
2710
+ };
2711
+ template <class T, memory_region Memory> class accessor<T, Memory, 2> {
2712
+ public:
2713
+ using memory_t = detail::memory_traits<Memory, T>;
2714
+ using element_t = typename memory_t::element_t;
2715
+ using pointer_t = typename memory_t::pointer_t;
2716
+ using accessor_t = typename memory_t::template accessor_t<2>;
2717
+ accessor(pointer_t data, const sycl::range<2> &in_range)
2718
+ : _data(data), _range(in_range) {}
2719
+ template <memory_region M = Memory>
2720
+ accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
2721
+ : accessor(acc, acc.get_range()) {}
2722
+ accessor(const accessor_t &acc, const sycl::range<2> &in_range)
2723
+ : accessor(acc.get_pointer(), in_range) {}
2724
+
2725
+ pointer_t operator[](size_t index) const {
2726
+ return _data + _range.get(1) * index;
2727
+ }
2728
+
2729
+ pointer_t get_ptr() const { return _data; }
2730
+
2731
+ private:
2732
+ pointer_t _data;
2733
+ sycl::range<2> _range;
2734
+ };
2735
+
2736
+ namespace detail {
2737
+ /// Device variable with address space of shared, global or constant.
2738
+ template <class T, memory_region Memory, size_t Dimension> class device_memory {
2739
+ public:
2740
+ using accessor_t =
2741
+ typename detail::memory_traits<Memory,
2742
+ T>::template accessor_t<Dimension>;
2743
+ using value_t = typename detail::memory_traits<Memory, T>::value_t;
2744
+ using dpct_accessor_t = dpct::accessor<T, Memory, Dimension>;
2745
+
2746
+ device_memory() : device_memory(sycl::range<Dimension>(1)) {}
2747
+
2748
+ /// Constructor of 1-D array with initializer list
2749
+ device_memory(const sycl::range<Dimension> &in_range,
2750
+ std::initializer_list<value_t> &&init_list)
2751
+ : device_memory(in_range) {
2752
+ assert(init_list.size() <= in_range.size());
2753
+ _host_ptr = (value_t *)std::malloc(_size);
2754
+ std::memset(_host_ptr, 0, _size);
2755
+ std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T));
2756
+ }
2757
+
2758
+ /// Constructor of 2-D array with initializer list
2759
+ template <size_t D = Dimension>
2760
+ device_memory(
2761
+ const typename std::enable_if<D == 2, sycl::range<2>>::type &in_range,
2762
+ std::initializer_list<std::initializer_list<value_t>> &&init_list)
2763
+ : device_memory(in_range) {
2764
+ assert(init_list.size() <= in_range[0]);
2765
+ _host_ptr = (value_t *)std::malloc(_size);
2766
+ std::memset(_host_ptr, 0, _size);
2767
+ auto tmp_data = _host_ptr;
2768
+ for (auto sub_list : init_list) {
2769
+ assert(sub_list.size() <= in_range[1]);
2770
+ std::memcpy(tmp_data, sub_list.begin(),
2771
+ sub_list.size() * sizeof(T));
2772
+ tmp_data += in_range[1];
2773
+ }
2774
+ }
2775
+
2776
+ /// Constructor with range
2777
+ device_memory(const sycl::range<Dimension> &range_in)
2778
+ : _size(range_in.size() * sizeof(T)), _range(range_in),
2779
+ _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) {
2780
+ static_assert(
2781
+ (Memory == global) || (Memory == constant) || (Memory == shared),
2782
+ "device memory region should be global, constant or shared");
2783
+ // Make sure that singleton class mem_mgr and dev_mgr will destruct
2784
+ // later than this.
2785
+ detail::mem_mgr::instance();
2786
+ dev_mgr::instance();
2787
+ }
2788
+
2789
+ /// Constructor with range
2790
+ template <class... Args>
2791
+ device_memory(Args... Arguments)
2792
+ : device_memory(sycl::range<Dimension>(Arguments...)) {}
2793
+
2794
+ ~device_memory() {
2795
+ if (_device_ptr && !_reference)
2796
+ dpct::dpct_free(_device_ptr);
2797
+ if (_host_ptr)
2798
+ std::free(_host_ptr);
2799
+ }
2800
+
2801
+ /// Allocate memory with default queue, and init memory if has initial
2802
+ /// value.
2803
+ void init() { init(dpct::get_default_queue()); }
2804
+ /// Allocate memory with specified queue, and init memory if has initial
2805
+ /// value.
2806
+ void init(sycl::queue &q) {
2807
+ if (_device_ptr)
2808
+ return;
2809
+ if (!_size)
2810
+ return;
2811
+ allocate_device(q);
2812
+ if (_host_ptr)
2813
+ detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size,
2814
+ host_to_device);
2815
+ }
2816
+
2817
+ /// The variable is assigned to a device pointer.
2818
+ void assign(value_t *src, size_t size) {
2819
+ this->~device_memory();
2820
+ new (this) device_memory(src, size);
2821
+ }
2822
+
2823
+ /// Get memory pointer of the memory object, which is virtual pointer when
2824
+ /// usm is not used, and device pointer when usm is used.
2825
+ value_t *get_ptr() { return get_ptr(get_default_queue()); }
2826
+ /// Get memory pointer of the memory object, which is virtual pointer when
2827
+ /// usm is not used, and device pointer when usm is used.
2828
+ value_t *get_ptr(sycl::queue &q) {
2829
+ init(q);
2830
+ return _device_ptr;
2831
+ }
2832
+
2833
+ /// Get the device memory object size in bytes.
2834
+ size_t get_size() { return _size; }
2835
+
2836
+ template <size_t D = Dimension>
2837
+ typename std::enable_if<D == 1, T>::type &operator[](size_t index) {
2838
+ init();
2839
+ return _device_ptr[index];
2840
+ }
2841
+
2842
+ /// Get dpct::accessor with dimension info for the device memory object
2843
+ /// when usm is used and dimension is greater than 1.
2844
+ template <size_t D = Dimension>
2845
+ typename std::enable_if<D != 1, dpct_accessor_t>::type
2846
+ get_access([[maybe_unused]] sycl::handler &cgh) {
2847
+ return dpct_accessor_t((T *)_device_ptr, _range);
2848
+ }
2849
+
2850
+ private:
2851
+ device_memory(value_t *memory_ptr, size_t size)
2852
+ : _size(size), _range(size / sizeof(T)), _reference(true),
2853
+ _device_ptr(memory_ptr) {}
2854
+
2855
+ void allocate_device(sycl::queue &q) {
2856
+ #ifndef DPCT_USM_LEVEL_NONE
2857
+ if (Memory == shared) {
2858
+ _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(),
2859
+ q.get_context());
2860
+ return;
2861
+ }
2862
+ #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY
2863
+ if (Memory == constant) {
2864
+ _device_ptr = (value_t *)sycl::malloc_device(
2865
+ _size, q.get_device(), q.get_context(),
2866
+ sycl::ext::oneapi::property::usm::device_read_only());
2867
+ return;
2868
+ }
2869
+ #endif
2870
+ #endif
2871
+ _device_ptr = (value_t *)detail::dpct_malloc(_size, q);
2872
+ }
2873
+
2874
+ size_t _size;
2875
+ sycl::range<Dimension> _range;
2876
+ bool _reference;
2877
+ value_t *_host_ptr;
2878
+ value_t *_device_ptr;
2879
+ };
2880
+ template <class T, memory_region Memory>
2881
+ class device_memory<T, Memory, 0> : public device_memory<T, Memory, 1> {
2882
+ public:
2883
+ using base = device_memory<T, Memory, 1>;
2884
+ using value_t = typename base::value_t;
2885
+ using accessor_t =
2886
+ typename detail::memory_traits<Memory, T>::template accessor_t<0>;
2887
+
2888
+ /// Constructor with initial value.
2889
+ device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}
2890
+
2891
+ /// Default constructor
2892
+ device_memory() : base(1) {}
2893
+ };
2894
+ } // namespace detail
2895
+
2896
+ template <class T, size_t Dimension>
2897
+ using global_memory = detail::device_memory<T, global, Dimension>;
2898
+ template <class T, size_t Dimension>
2899
+ using constant_memory = detail::device_memory<T, constant, Dimension>;
2900
+ template <class T, size_t Dimension>
2901
+ using shared_memory = detail::device_memory<T, shared, Dimension>;
2902
+
2903
+
2904
+ template <typename T,
2905
+ sycl::access::address_space addressSpace =
2906
+ sycl::access::address_space::global_space,
2907
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2908
+ sycl::memory_scope memoryScope = sycl::memory_scope::device>
2909
+ inline T atomic_fetch_add(T *addr, T operand) {
2910
+ auto atm =
2911
+ sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
2912
+ return atm.fetch_add(operand);
2913
+ }
2914
+
2915
+ template <sycl::access::address_space addressSpace =
2916
+ sycl::access::address_space::global_space,
2917
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2918
+ sycl::memory_scope memoryScope = sycl::memory_scope::device,
2919
+ typename T1, typename T2>
2920
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
2921
+ auto atm =
2922
+ sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
2923
+ return atm.fetch_add(operand);
2924
+ }
2925
+
2926
+ template <typename T, sycl::access::address_space addressSpace =
2927
+ sycl::access::address_space::global_space>
2928
+ inline T atomic_fetch_add(T *addr, T operand,
2929
+ sycl::memory_order memoryOrder) {
2930
+ switch (memoryOrder) {
2931
+ case sycl::memory_order::relaxed:
2932
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
2933
+ sycl::memory_scope::device>(addr, operand);
2934
+ case sycl::memory_order::acq_rel:
2935
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
2936
+ sycl::memory_scope::device>(addr, operand);
2937
+ case sycl::memory_order::seq_cst:
2938
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
2939
+ sycl::memory_scope::device>(addr, operand);
2940
+ default:
2941
+ assert(false && "Invalid memory_order for atomics. Valid memory_order for "
2942
+ "atomics are: sycl::memory_order::relaxed, "
2943
+ "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
2944
+ }
2945
+ }
2946
+
2947
+ template <sycl::access::address_space addressSpace =
2948
+ sycl::access::address_space::global_space,
2949
+ typename T1, typename T2>
2950
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand,
2951
+ sycl::memory_order memoryOrder) {
2952
+ atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
2953
+ }
2954
+
2955
+ } // COPY from DPCT head files
2956
+
2957
+ #endif // GGML_SYCL_DPCT_HELPER_HPP