@fugood/llama.node 0.3.16 → 0.4.0

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -37,6 +37,8 @@
  #include "ggml-backend-impl.h"

  #include "ggml-sycl/backend.hpp"
+ #include "ggml-sycl/common.hpp"
+ #include "ggml-sycl/element_wise.hpp"
  #include "ggml-sycl/presets.hpp"
  #include "ggml-sycl/gemm.hpp"
  #include "ggml-sycl/sycl_hw.hpp"
@@ -47,6 +49,8 @@ static bool g_sycl_loaded = false;
  int g_ggml_sycl_debug = 0;
  int g_ggml_sycl_disable_optimize = 0;
  int g_ggml_sycl_disable_graph = 0;
+ int g_ggml_sycl_disable_dnn = 0;
+ int g_ggml_sycl_prioritize_dmmv = 0;

  static ggml_sycl_device_info ggml_sycl_init() {
  ggml_sycl_device_info info = {};
@@ -191,13 +195,25 @@ static void ggml_check_sycl() try {

  if (!initialized) {
  g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
- g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+ g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
  g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+ g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
+ g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
  GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
  GGML_LOG_INFO("Running with Environment Variables:\n");
  GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
  GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+ #ifdef GGML_SYCL_GRAPH
  GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+ #else
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+ #endif
+ #if GGML_SYCL_DNNL
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+ #else
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+ #endif
+ GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
  GGML_LOG_INFO("Build with Macros:\n");
  #if defined(GGML_SYCL_FORCE_MMQ)
  GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
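The hunk above adds two runtime toggles, GGML_SYCL_DISABLE_DNN and GGML_SYCL_PRIORITIZE_DMMV, read at startup through the same get_sycl_env helper as the existing flags. A minimal sketch of how such an integer environment probe behaves (the helper name and parsing below are an illustration, not the upstream implementation):

    #include <cstdlib>
    #include <string>

    // Assumed behavior: return the integer value of an environment variable,
    // or a default when the variable is unset.
    static int get_int_env(const char * name, int default_value) {
        const char * val = std::getenv(name);
        return val != nullptr ? std::stoi(std::string(val)) : default_value;
    }

    // e.g. running with GGML_SYCL_DISABLE_DNN=1 would make this return 1
    // and switch the GEMM paths below away from oneDNN at runtime:
    // int disable_dnn = get_int_env("GGML_SYCL_DISABLE_DNN", 0);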
@@ -336,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
  assert(tensor->view_src->buffer->buft == buffer->buft);
  return GGML_STATUS_SUCCESS;
  }
- if (tensor->type == GGML_TYPE_Q4_0) {
+ if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
  ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
  tensor->extra = extra;
  ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -371,6 +387,8 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
  SYCL_CHECK(
  CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
+ // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU.
+ // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here.
  char* host_buf = (char*)malloc(size);
  memcpy(host_buf, data, size);
  SYCL_CHECK(
@@ -490,6 +508,23 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

+ static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
+ size_t offset, size_t size) {
+ GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__);
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+ SYCL_CHECK(ggml_sycl_set_device(ctx->device));
+ auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
+ if (size == 0) {
+ return; // Nothing to do
+ }
+ if (tensor->data == nullptr) {
+ GGML_ABORT("Error: Tensor data pointer is null.\n");
+ }
+ void * target_ptr = static_cast<char *>(tensor->data) + offset;
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
+ }
+
  static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
  GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
  if (buffer == nullptr) {
@@ -510,7 +545,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
  /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
  /* .get_base = */ ggml_backend_sycl_buffer_get_base,
  /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
+ /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
  /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
  /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
  /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
@@ -1597,17 +1632,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
  dst[i] = scale * x[i];
  }

- static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
- const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= k) {
- return;
- }
-
- dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
- }

  template <typename Ti, typename To>
  static void pool2d_nchw_kernel(
@@ -1748,18 +1772,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
  });
  }

- static void clamp_f32_sycl(const float *x, float *dst, const float min,
- const float max, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- clamp_f32(x, dst, min, max, k, item_ct1);
- });
- }

  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
  const int nrows, queue_ptr stream) {
@@ -1970,19 +1982,6 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

- static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_d, const float *src1_d,
- float *dst_d,
- const queue_ptr &main_stream) {
-
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(src1_d);
- }
-
-
  inline void ggml_sycl_op_mul_mat_sycl(
  ggml_backend_sycl_context & ctx,
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1997,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(

  const int64_t ne00 = src0->ne[0];
  const int64_t ne10 = src1->ne[0];
-
+ GGML_ASSERT(ne00 == ne10);

  const int64_t row_diff = row_high - row_low;

  int id;
  SYCL_CHECK(
  CHECK_TRY_ERROR(id = get_current_device_id()));
- #if !GGML_SYCL_DNNL
- const int64_t ne0 = dst->ne[0];
+
+ const int64_t ne0 = dst->ne[0]; // used by MKL only
  // the main device has a larger memory buffer to hold the results from all GPUs
  // ldc == nrows of the matrix that cuBLAS writes into
- int ldc = id == ctx.device ? ne0 : row_diff;
- #endif
+ int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only

  #ifdef GGML_SYCL_F16
  bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2045,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
  : src1_as_f16.get();
  ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);

- #if !GGML_SYCL_DNNL
- const sycl::half alpha_f16 = 1.0f;
- const sycl::half beta_f16 = 0.0f;
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
- *stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
- src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
- dst_f16.get(), dpct::library_data_t::real_half, ldc,
- dpct::library_data_t::real_half)));
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
- #else
- DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
- DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
- dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+ #if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+ DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+ dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+ }
+ else
  #endif
+ {
+ const sycl::half alpha_f16 = 1.0f;
+ const sycl::half beta_f16 = 0.0f;
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+ *stream, oneapi::math::transpose::trans,
+ oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+ src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+ dst_f16.get(), dpct::library_data_t::real_half, ldc,
+ dpct::library_data_t::real_half)));
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ }
  }
  else {
  // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2084,25 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
  const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
  const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();

- #if !GGML_SYCL_DNNL
- const float alpha = 1.0f;
- const float beta = 0.0f;
- # ifdef GGML_SYCL_NVIDIA
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
- ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
- # else
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
- dst_dd_i, ldc)));
- # endif
- #else
- DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
- DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
- dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+ #if GGML_SYCL_DNNL
+ if (!g_ggml_sycl_disable_dnn) {
+ DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+ DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+ }
+ else
  #endif
+ {
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+ get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+ src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+ dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+ }
  }
  GGML_UNUSED(dst);
  GGML_UNUSED(src1_ddq_i);
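Both GEMM hunks above follow the same dispatch pattern: a compile-time oneDNN check (GGML_SYCL_DNNL) wrapped around the new runtime opt-out (g_ggml_sycl_disable_dnn), with the oneMath call kept as the fallback. A condensed sketch of that control flow, with the actual GEMM calls replaced by hypothetical placeholders:

    // Sketch only: run_dnnl_gemm()/run_onemath_gemm() stand in for the
    // DnnlGemmWrapper::row_gemm and oneapi::math gemm calls shown above.
    void run_dnnl_gemm();     // hypothetical placeholder
    void run_onemath_gemm();  // hypothetical placeholder

    void gemm_dispatch_sketch() {
    #if GGML_SYCL_DNNL
        if (!g_ggml_sycl_disable_dnn) {
            run_dnnl_gemm();      // oneDNN path: compiled in and not disabled at runtime
            return;
        }
    #endif
        run_onemath_gemm();       // oneMath/MKL path: always compiled, used as the fallback
    }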
@@ -2114,13 +2113,14 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

- static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd, const queue_ptr &main_stream) {
+ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);

  const int32_t * opts = (const int32_t *)dst->op_params;
  enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
@@ -2131,8 +2131,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
  const int p0 = opts[5];
  const int p1 = opts[6];

- const int64_t IH = src0->ne[1];
- const int64_t IW = src0->ne[0];
+ const int64_t IH = dst->src[0]->ne[1];
+ const int64_t IW = dst->src[0]->ne[0];

  const int64_t N = dst->ne[3];
  const int64_t OC = dst->ne[2];
@@ -2151,163 +2151,105 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
  parallel_elements, src0_dd, dst_dd, op,
  item_ct1);
  });
-
- GGML_UNUSED(src1);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }

- inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);

- const int64_t ne = ggml_nelements(src0);
+ const int64_t ne = ggml_nelements(dst->src[0]);

  sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }

- inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);

- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);

  sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }

- inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);

- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);

  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

- argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
+ argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
  }

- inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_I32);

- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);

- argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);

- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
+ argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
  }

- inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) {

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);

- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int nrows0 = ggml_nrows(src0);
+ const int64_t ne00 = dst->src[0]->ne[0];
+ const int64_t ne01 = dst->src[0]->ne[1];
+ const int nrows0 = ggml_nrows(dst->src[0]);

  const int n_past = ((int32_t *) dst->op_params)[0];

  diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }

- inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);

  float scale;
  memcpy(&scale, dst->op_params, sizeof(float));

- scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
+ scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
  /*
  DPCT1010:87: SYCL uses exceptions to report errors and does not use the
  error codes. The call was replaced with 0. You need to rewrite this code.
  */
  SYCL_CHECK(0);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
- }
-
- inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- float min;
- float max;
- memcpy(&min, dst->op_params, sizeof(float));
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
- clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
- /*
- DPCT1010:88: SYCL uses exceptions to report errors and does not use the
- error codes. The call was replaced with 0. You need to rewrite this code.
- */
- SYCL_CHECK(0);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }

  static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
@@ -2675,39 +2617,33 @@ catch (sycl::exception const &exc) {
  }


- static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
- }
-
  static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows);
+ ggml_sycl_op_get_rows(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }

  static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm);
+ ggml_sycl_op_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }

  static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm);
+ ggml_sycl_op_rms_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }

  static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
+ ggml_sycl_op_l2_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }

  static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
+ ggml_sycl_op_group_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }

@@ -2779,143 +2715,180 @@ catch (sycl::exception const &exc) {
2779
2715
  std::exit(1);
2780
2716
  }
2781
2717
 
2782
- static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
2783
- const sycl::half *src1_as_f16, char *dst,
2784
- const void **ptrs_src, void **ptrs_dst,
2785
- int64_t ne12, int64_t ne13, int64_t ne23,
2786
- size_t nb02, size_t nb03, size_t nb12,
2787
- size_t nb13, size_t nbd2, size_t nbd3,
2788
- int64_t r2, int64_t r3,
2789
- const sycl::nd_item<3> &item_ct1) {
2790
- int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
2791
- item_ct1.get_local_id(2);
2792
- int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
2793
- item_ct1.get_local_id(1);
2718
+ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
2719
+ const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
2720
+ size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
2721
+ int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
2722
+ const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2);
2723
+ const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
2794
2724
 
2795
2725
  if (i13 >= ne13 || i12 >= ne12) {
2796
2726
  return;
2797
2727
  }
2798
2728
 
2799
- int64_t i03 = i13 / r3;
2800
- int64_t i02 = i12 / r2;
2729
+ const int64_t i03 = i13 / r3;
2730
+ const int64_t i02 = i12 / r2;
2801
2731
 
2802
- ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
2803
- ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
2804
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
2732
+ const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
2733
+ const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
2734
+ uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
2735
+
2736
+ ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
2737
+ ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
2738
+ ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3;
2805
2739
  }
2806
2740
 
2807
- static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
2808
- const ggml_tensor *src0,
2809
- const ggml_tensor *src1,
2810
- ggml_tensor *dst) try {
2741
+ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
2742
+ const ggml_tensor * src1, ggml_tensor * dst) try {
2811
2743
  GGML_ASSERT(!ggml_is_transposed(src0));
2812
2744
  GGML_ASSERT(!ggml_is_transposed(src1));
2813
2745
  GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
2814
2746
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
2747
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2815
2748
 
2816
2749
  GGML_TENSOR_BINARY_OP_LOCALS
2817
2750
 
2751
+ // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155
2752
+ // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst
2753
+ GGML_ASSERT(ggml_is_contiguous(dst));
2818
2754
 
2819
2755
  SYCL_CHECK(ggml_sycl_set_device(ctx.device));
2820
- queue_ptr main_stream = ctx.stream();;
2756
+ queue_ptr queue = ctx.stream();
2821
2757
 
2822
- void * src0_ddq = src0->data;
2823
- sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
2824
- float * src1_ddf = (float *) src1->data;
2825
- float * dst_ddf = (float *) dst->data;
2758
+ dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 });
2826
2759
 
2827
- // convert src1 to fp16
2760
+ const sycl::half * src0_f16 = static_cast<const sycl::half *>(src0->data);
2761
+ float * dst_ddf = static_cast<float *>(dst->data);
2762
+
2763
+ const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
2764
+ const size_t type_size_src1 = ggml_type_size(src1->type);
2765
+ GGML_ASSERT(nb10 == type_size_src1);
2766
+
2767
+ // SRC1 strides
2768
+ int64_t s11 = nb11 / type_size_src1;
2769
+ int64_t s12 = nb12 / type_size_src1;
2770
+ int64_t s13 = nb13 / type_size_src1;
2828
2771
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
2772
+
2773
+ // convert src1 to fp16
2829
2774
  if (src1->type != GGML_TYPE_F16) {
2830
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
2775
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
2776
+ GGML_ASSERT(to_fp16_nc_sycl != nullptr);
2831
2777
  const int64_t ne_src1 = ggml_nelements(src1);
2832
2778
  src1_f16_alloc.alloc(ne_src1);
2833
- GGML_ASSERT(to_fp16_sycl != nullptr);
2834
- to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
2779
+ to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
2780
+
2781
+ src1_f16 = src1_f16_alloc.get();
2782
+ s11 = ne10;
2783
+ s12 = ne11 * s11;
2784
+ s13 = ne12 * s12;
2835
2785
  }
2836
- sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
2837
- : src1_f16_alloc.get();
2838
2786
 
2839
- char * dst_t;
2787
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
2840
2788
 
2841
- dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
2842
- dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
2789
+ dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
2790
+ dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
2843
2791
 
2844
2792
  // dst strides
2845
2793
  size_t nbd2 = dst->nb[2];
2846
2794
  size_t nbd3 = dst->nb[3];
2847
2795
 
2848
2796
  const float alpha_f32 = 1.0f;
2849
- const float beta_f32 = 0.0f;
2797
+ const float beta_f32 = 0.0f;
2850
2798
 
2851
2799
  const void * alpha = &alpha_f32;
2852
2800
  const void * beta = &beta_f32;
2853
2801
 
2854
- dst_t = (char *) dst_ddf;
2855
-
2856
2802
  GGML_ASSERT(ne12 % ne02 == 0);
2857
2803
  GGML_ASSERT(ne13 % ne03 == 0);
2804
+ GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
2805
+ GGML_ASSERT(ne10 == ne00);
2858
2806
 
2859
2807
  // broadcast factors
2860
- const int64_t r2 = ne12/ne02;
2861
- const int64_t r3 = ne13/ne03;
2862
-
2863
- if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2864
- // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2865
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2866
- *main_stream, oneapi::mkl::transpose::trans,
2867
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
2868
- (const char *)src0_as_f16, dpct::library_data_t::real_half,
2869
- nb01 / nb00, nb02 / nb00,
2870
- (const char *)src1_f16, dpct::library_data_t::real_half,
2871
- nb11 / nb10, nb12 / nb10, beta,
2872
- (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
2873
- ne12 * ne13, cu_compute_type)));
2874
- } else {
2875
- const int ne23 = ne12*ne13;
2876
-
2877
- ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
2878
- ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
2879
- ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
2880
-
2881
- sycl::range<3> block_dims(1, ne12, ne13);
2882
- /*
2883
- DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
2884
- the limit. To get the device limit, query
2885
- info::device::max_work_group_size. Adjust the work-group size if needed.
2886
- */
2887
- {
2888
- dpct::has_capability_or_fail(main_stream->get_device(),
2889
- {sycl::aspect::fp16});
2890
-
2891
- main_stream->submit([&](sycl::handler &cgh) {
2892
- const void **ptrs_src_get = ptrs_src.get();
2893
- void **ptrs_dst_get = ptrs_dst.get();
2894
- size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
2895
- size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
2896
- cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
2897
- [=](sycl::nd_item<3> item_ct1) {
2898
- k_compute_batched_ptrs(
2899
- src0_as_f16, src1_f16,
2900
- dst_t, ptrs_src_get,
2901
- ptrs_dst_get, ne12, ne13, ne23,
2902
- nb02, nb03, nb12_scaled, nb13_scaled,
2903
- nbd2, nbd3, r2, r3, item_ct1);
2904
- });
2808
+ const int64_t r2 = ne12 / ne02;
2809
+ const int64_t r3 = ne13 / ne03;
2810
+
2811
+ #if GGML_SYCL_DNNL
2812
+ if (!g_ggml_sycl_disable_dnn) {
2813
+ auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
2814
+ (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
2815
+
2816
+ DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
2817
+ src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
2818
+ src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
2819
+ dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
2820
+ };
2821
+
2822
+ if (r2 == 1 && r3 == 1) {
2823
+ if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2824
+ dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
2825
+ }
2826
+ else {
2827
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
2828
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
2829
+ const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
2830
+ float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
2831
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
2832
+ }
2833
+ }
2834
+ } else {
2835
+ // iterate over batches from smaller set of matrices (matrix 0)
2836
+ for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
2837
+ for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
2838
+ const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
2839
+ const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
2840
+ float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
2841
+ dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
2842
+ }
2843
+ }
2844
+ }
2845
+ }
2846
+ else
2847
+ #endif
2848
+ {
2849
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
2850
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2851
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
2852
+ oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2853
+ src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
2854
+ src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
2855
+ mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
2856
+ } else {
2857
+ const int ne23 = ne12 * ne13;
2858
+
2859
+ ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
2860
+ ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
2861
+ ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
2862
+
2863
+ sycl::range<3> block_dims(1, ne12, ne13);
2864
+ queue->submit([&](sycl::handler & cgh) {
2865
+ const void ** ptrs_src_get = ptrs_src.get();
2866
+ void ** ptrs_dst_get = ptrs_dst.get();
2867
+ size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
2868
+ size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
2869
+ cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
2870
+ k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
2871
+ nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
2872
+ });
2905
2873
  });
2874
+
2875
+ SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2876
+ *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
2877
+ (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
2878
+ (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
2879
+ (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
2906
2880
  }
2907
- SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
2908
- *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
2909
- (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
2910
- (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
2911
- (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
2912
2881
  }
2882
+ } catch (const sycl::exception & exc) {
2883
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
2884
+ std::exit(1);
2913
2885
  }
2914
- catch (sycl::exception const &exc) {
2915
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
2916
- << ", line:" << __LINE__ << std::endl;
2917
- std::exit(1);
2918
- }
2886
+
2887
+ enum class mul_mat_algo {
2888
+ DMMV = 0,
2889
+ MMVQ = 1,
2890
+ MUL_MAT_SYCL = 2,
2891
+ };
2919
2892
 
2920
2893
  inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
2921
2894
  // TODO: accuracy issues in MMQ
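The batched path above maps every (i12, i13) batch of src1 onto a src0 matrix through the broadcast factors r2 = ne12/ne02 and r3 = ne13/ne03, shifting the base pointer by ggml's byte strides (nb02, nb03). The standalone sketch below models only that index arithmetic; the `TensorView` struct and `src0_for_batch` helper are illustrative names, not part of ggml or the SYCL backend.

```cpp
// Standalone sketch (not the ggml API): models how the batched path derives the
// per-matrix src0 pointer when src0 is broadcast across dims 2/3. Assumes ggml's
// convention that ne[] are element counts and nb[] are byte strides.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct TensorView {
    int64_t ne[4];   // element counts per dimension
    size_t  nb[4];   // byte strides per dimension
    const uint8_t * data;
};

// For every (i12, i13) batch of src1, pick the src0 matrix it multiplies with.
// r2 = ne12 / ne02 and r3 = ne13 / ne03 are the broadcast factors, so src0's
// batch index is simply i12 / r2 and i13 / r3.
static const uint8_t * src0_for_batch(const TensorView & src0,
                                      int64_t i12, int64_t i13,
                                      int64_t r2, int64_t r3) {
    const int64_t i02 = i12 / r2;
    const int64_t i03 = i13 / r3;
    return src0.data + i02 * src0.nb[2] + i03 * src0.nb[3];
}

int main() {
    std::vector<uint8_t> buf(2 * 3 * 16, 0);
    TensorView src0{{4, 1, 2, 3}, {2, 8, 8, 16}, buf.data()};
    const int64_t ne12 = 4, ne13 = 3;        // src1 batch counts
    const int64_t r2 = ne12 / src0.ne[2];    // broadcast factor on dim 2
    const int64_t r3 = ne13 / src0.ne[3];    // broadcast factor on dim 3
    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            const uint8_t * p = src0_for_batch(src0, i12, i13, r2, r3);
            std::printf("batch (%lld,%lld) -> offset %td\n",
                        (long long) i12, (long long) i13, p - buf.data());
        }
    }
    return 0;
}
```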
@@ -2923,6 +2896,36 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
2923
2896
  return false;
2924
2897
  }
2925
2898
 
2899
+ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
2900
+ switch (type) {
2901
+ case GGML_TYPE_Q4_0:
2902
+ return true;
2903
+ case GGML_TYPE_Q4_K:
2904
+ return !g_ggml_sycl_prioritize_dmmv;
2905
+ default:
2906
+ return false;
2907
+ }
2908
+ }
2909
+
2910
+ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
2911
+ switch (type) {
2912
+ case GGML_TYPE_Q4_0:
2913
+ return true;
2914
+ default:
2915
+ return false;
2916
+ }
2917
+ }
2918
+
2919
+ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
2920
+ switch (type) {
2921
+ case GGML_TYPE_Q4_0:
2922
+ case GGML_TYPE_Q4_K:
2923
+ return true;
2924
+ default:
2925
+ return false;
2926
+ }
2927
+ }
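Taken together, the three helpers just added act as a small capability table: which quantized types each mul-mat algorithm can consume in reordered form. A minimal sketch of the same mapping, using illustrative enum names rather than the ggml types:

```cpp
// Minimal sketch (assumed names, not the ggml API): the three reorder-support helpers
// form a per-algorithm capability table; this models the same dispatch as one lookup.
#include <cstdio>

enum class mul_mat_algo { DMMV, MMVQ, MUL_MAT_SYCL };
enum class quant_type   { Q4_0, Q4_K, OTHER };

// prioritize_dmmv mirrors the g_ggml_sycl_prioritize_dmmv override used above.
static bool supports_reorder(mul_mat_algo algo, quant_type t, bool prioritize_dmmv) {
    switch (algo) {
        case mul_mat_algo::DMMV:         return t == quant_type::Q4_0;
        case mul_mat_algo::MMVQ:         return t == quant_type::Q4_0 || t == quant_type::Q4_K;
        case mul_mat_algo::MUL_MAT_SYCL: return t == quant_type::Q4_0 ||
                                                (t == quant_type::Q4_K && !prioritize_dmmv);
    }
    return false;
}

int main() {
    const char * algo_names[] = {"DMMV", "MMVQ", "MUL_MAT_SYCL"};
    const char * type_names[] = {"Q4_0", "Q4_K", "OTHER"};
    for (int a = 0; a < 3; ++a) {
        for (int t = 0; t < 3; ++t) {
            std::printf("%-12s %-5s reorder=%d\n", algo_names[a], type_names[t],
                        supports_reorder((mul_mat_algo) a, (quant_type) t, false));
        }
    }
    return 0;
}
```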
2928
+
2926
2929
  static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
2927
2930
  switch (type) {
2928
2931
  case GGML_TYPE_Q4_0:
@@ -2942,13 +2945,142 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
2942
2945
  }
2943
2946
  }
2944
2947
 
2945
- static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2948
+ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
2949
+ dpct::queue_ptr stream) {
2950
+ auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
2951
+ SYCL_CHECK(
2952
+ CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
2953
+ .wait()));
2954
+ GGML_ASSERT((size % sizeof(block_q4_0) == 0));
2955
+ GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
2956
+ int offset_blks = offset / sizeof(block_q4_0);
2957
+ auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
2958
+ auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
2959
+
2960
+ stream->parallel_for(
2961
+ size / sizeof(block_q4_0),
2962
+ [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
2963
+ const block_q4_0* x = (const block_q4_0*)tmp_buf;
2964
+ const int ib = i;
2965
+
2966
+ for (int j = 0; j < QK4_0/2; j ++)
2967
+ {
2968
+ *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
2969
+ }
2970
+ *(d_ptr + ib) = x[ib].d;
2971
+ }).wait_and_throw();
2972
+
2973
+ sycl::free(tmp_buf, *stream);
2974
+ }
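reorder_qw_q4_0 above rewrites the weight buffer from an array of {d, qs[16]} blocks into a struct-of-arrays layout, with all nibble data first and all scales after it, so the reordered kernels can read each field with unit stride. Below is a host-only sketch of the same rewrite under stand-in types (uint16_t standing in for the fp16 scale); it is illustrative, not the backend code.

```cpp
// Host-side sketch of the Q4_0 reorder (stand-in types, not the real ggml structs):
// blocks stored as {d, qs[16]} are rewritten so all quantized nibbles come first and
// all scales follow, which is the layout the reordered kernels expect.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK4_0 = 32;                 // elements per Q4_0 block
struct block_q4_0_sim {                   // stand-in for ggml's block_q4_0
    uint16_t d;                           // scale (fp16 bits in the real struct)
    uint8_t  qs[QK4_0 / 2];               // 32 x 4-bit quants packed into 16 bytes
};

// Rewrites `n_blocks` blocks in place: [qs0..qsN | d0..dN] instead of interleaved.
static void reorder_q4_0(uint8_t * data, size_t n_blocks) {
    std::vector<uint8_t> tmp(data, data + n_blocks * sizeof(block_q4_0_sim));
    const auto * src = reinterpret_cast<const block_q4_0_sim *>(tmp.data());
    uint8_t  * qs_out = data;                                    // nibbles first
    uint16_t * d_out  = reinterpret_cast<uint16_t *>(data + n_blocks * (QK4_0 / 2));
    for (size_t ib = 0; ib < n_blocks; ++ib) {
        std::memcpy(qs_out + ib * (QK4_0 / 2), src[ib].qs, QK4_0 / 2);
        std::memcpy(d_out + ib, &src[ib].d, sizeof(uint16_t));
    }
}

int main() {
    std::vector<uint8_t> weights(4 * sizeof(block_q4_0_sim), 0x11);
    reorder_q4_0(weights.data(), 4);
    return 0;
}
```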
2975
+
2976
+ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
2977
+ GGML_ASSERT(size % sizeof(block_q4_K) == 0);
2978
+ GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
2979
+
2980
+ const int nblocks = size / sizeof(block_q4_K);
2981
+
2982
+ auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
2983
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
2984
+
2985
+ auto * qs_ptr = data_device;
2986
+ auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
2987
+ auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
2988
+
2989
+ stream->parallel_for(nblocks, [=](auto i) {
2990
+ const block_q4_K * x = (const block_q4_K *) tmp_buf;
2991
+ const int ib = i;
2992
+
2993
+ for (int j = 0; j < QK_K / 2; ++j) {
2994
+ qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
2995
+ }
2996
+
2997
+ for (int j = 0; j < K_SCALE_SIZE; ++j) {
2998
+ scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
2999
+ }
3000
+
3001
+ dm_ptr[ib] = x[ib].dm;
3002
+ }).wait_and_throw();
3003
+
3004
+ sycl::free(tmp_buf, *stream);
3005
+ }
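For Q4_K the same idea splits each block into three regions: quantized nibbles, packed scales, and the d/dmin half2 pair. A small sketch of the resulting offsets, assuming the upstream constants QK_K = 256 and K_SCALE_SIZE = 12 (stand-ins written out here, not pulled from the real ggml headers):

```cpp
// Companion sketch for the Q4_K reorder: computes where each field region starts once
// a buffer of n blocks is rewritten as [all qs | all scales | all dm], and checks the
// reordered layout still fits the original allocation exactly.
#include <cassert>
#include <cstddef>
#include <cstdio>

constexpr size_t QK_K          = 256;  // elements per Q4_K super-block
constexpr size_t K_SCALE_SIZE  = 12;   // packed scale/min bytes per block
constexpr size_t DM_SIZE       = 4;    // sycl::half2 (two fp16 values)
constexpr size_t BLOCK_Q4_K_SZ = DM_SIZE + K_SCALE_SIZE + QK_K / 2;  // 144 bytes

struct q4_k_soa_offsets {
    size_t qs;      // start of all quantized nibbles
    size_t scales;  // start of all packed scales
    size_t dm;      // start of all d/dmin pairs
    size_t total;   // bytes used by the reordered layout
};

static q4_k_soa_offsets layout_for(size_t nblocks) {
    q4_k_soa_offsets o{};
    o.qs     = 0;
    o.scales = nblocks * (QK_K / 2);
    o.dm     = o.scales + nblocks * K_SCALE_SIZE;
    o.total  = o.dm + nblocks * DM_SIZE;
    return o;
}

int main() {
    const size_t nblocks = 8;
    const q4_k_soa_offsets o = layout_for(nblocks);
    assert(o.total == nblocks * BLOCK_Q4_K_SZ);   // reorder is done in place
    std::printf("qs@%zu scales@%zu dm@%zu total=%zu\n", o.qs, o.scales, o.dm, o.total);
    return 0;
}
```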
3006
+
3007
+ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
3008
+ uint8_t * data_device = (uint8_t *) src0->data;
3009
+ size_t ncols = src0->ne[0];
3010
+ size_t nrows = src0->ne[1];
3011
+ size_t size = ggml_nbytes(src0);
3012
+
3013
+ switch (src0->type) {
3014
+ case GGML_TYPE_Q4_0:
3015
+ reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
3016
+ break;
3017
+ case GGML_TYPE_Q4_K:
3018
+ reorder_qw_q4_k(data_device, size, 0, stream);
3019
+ break;
3020
+ default:
3021
+ GGML_ABORT("reorder_qw() called with unsupported type");
3022
+ break;
3023
+ }
3024
+ }
3025
+
3026
+ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
3027
+ return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
3028
+ ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
3029
+ dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
3030
+ dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
3031
+ }
3032
+
3033
+ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
3034
+ ggml_tensor * dst, mul_mat_algo mm_algorithm) {
3035
+ if (!should_reorder_tensor(*ctx, dst)) {
3036
+ return;
3037
+ }
3038
+
3039
+ ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3040
+ if (!extra || extra->optimized_feature.reorder) {
3041
+ return; // Skip permutations and already reordered tensors
3042
+ }
3043
+
3044
+ switch (mm_algorithm) {
3045
+ case mul_mat_algo::DMMV:
3046
+ if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
3047
+ return;
3048
+ }
3049
+ break;
3050
+ case mul_mat_algo::MMVQ:
3051
+ if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
3052
+ return;
3053
+ }
3054
+ break;
3055
+ case mul_mat_algo::MUL_MAT_SYCL:
3056
+ if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
3057
+ return;
3058
+ }
3059
+ break;
3060
+ }
3061
+
3062
+ reorder_qw(src0, ctx->stream());
3063
+ extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
3064
+ }
2946
3065
 
3066
+
3067
+ static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3068
+ return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3069
+ src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
3070
+ }
3071
+
3072
+ static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
3073
+ return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
3074
+ src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
3075
+ }
3076
+
3077
+ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2947
3078
  const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
2948
3079
  int64_t min_compute_capability = INT_MAX;
2949
3080
 
2950
3081
  if (split) {
2951
- ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
3082
+ ggml_backend_sycl_split_buffer_type_context * buft_ctx =
3083
+ (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
2952
3084
  auto & tensor_split = buft_ctx->tensor_split;
2953
3085
  for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
2954
3086
  // skip devices that are not going to do any work:
@@ -2961,17 +3093,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
2961
3093
  }
2962
3094
  }
2963
3095
  } else {
2964
- min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
3096
+ min_compute_capability = ggml_sycl_info().devices[ctx.device].cc;
2965
3097
  }
2966
3098
 
2967
3099
  // check data types and tensor shapes for custom matrix multiplication kernels:
2968
- bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
2969
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
2970
- && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
3100
+ bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
2971
3101
 
2972
- bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
2973
- && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
2974
- && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
3102
+ bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
2975
3103
 
2976
3104
  bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
2977
3105
  && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
@@ -2983,9 +3111,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
2983
3111
  use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
2984
3112
  #endif // SYCL_USE_XMX
2985
3113
 
3114
+
2986
3115
  // mmvq path is faster in the CUDA backend.
2987
- if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
3116
+ if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
3117
+ // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
3118
+ // is enabled takes precedence over DMMV, the current if-else implementation
3119
+ // requires disabling DMMV if both conditions are met
3120
+ || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
2988
3121
  use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
3122
+ }
2989
3123
 
2990
3124
  if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
2991
3125
  // TODO: Refactor and cleanup of mul mat dispatching.
@@ -2997,22 +3131,30 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
2997
3131
  // The kernel from the if path is faster for that specific case, but does not support all mul mats.
2998
3132
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
2999
3133
  }
3000
- } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
3134
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
3001
3135
  // KQV single-batch
3002
3136
  ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
3003
3137
  } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
3004
3138
  // KQ + KQV multi-batch
3005
3139
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3006
3140
  } else if (use_dequantize_mul_mat_vec) {
3007
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
3008
- // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
3141
+ constexpr bool convert_src1_to_q8_1 = false;
3142
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
3143
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
3009
3144
  } else if (use_mul_mat_vec_q) {
3010
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
3145
+ constexpr bool convert_src1_to_q8_1 = true;
3146
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
3147
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
3011
3148
  } else if (use_mul_mat_q) {
3012
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
3149
+ constexpr bool convert_src1_to_q8_1 = true;
3150
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
3013
3151
  } else {
3014
- ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
3152
+ constexpr bool convert_src1_to_q8_1 = false;
3153
+ // MUL_MAT_SYCL supports reorder
3154
+ opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MUL_MAT_SYCL);
3155
+ ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
3015
3156
  }
3157
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
3016
3158
  }
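The rewritten ggml_sycl_mul_mat above first computes the candidate-kernel predicates, then demotes DMMV in favour of MMVQ on the CUDA backend or when the reorder optimization applies, and finally walks the if/else chain, calling opt_for_reorder only on paths that support reordered weights. The sketch below is a simplified standalone model of that ordering (it drops the F16 special cases, the g_ggml_sycl_prioritize_dmmv override, and the exact shape checks); the names mirror the diff, but none of this is the real API.

```cpp
// Standalone model (not the real backend) of the mul-mat dispatch order: predicates
// first, MMVQ precedence over DMMV where applicable, then an ordered if/else chain.
#include <cstdio>
#include <string>

struct mm_case {
    bool quantized_src0;     // src0 is a quantized type
    bool dmmv_supported;     // ggml_sycl_supports_dmmv(src0->type)
    bool single_column;      // src1->ne[1] == 1
    bool small_batch;        // src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
    bool mmq_supported;      // ggml_sycl_supports_mmq(src0->type)
    bool cuda_backend;       // SYCL running on the CUDA backend
    bool reorder_wants_mmvq; // should_reorder_tensor && supports_reorder_mmvq
};

static std::string pick_path(const mm_case & c) {
    bool use_dmmv = c.dmmv_supported && c.single_column;
    bool use_mmvq = c.quantized_src0 && c.small_batch;
    bool use_mmq  = c.mmq_supported;
    if (c.cuda_backend || c.reorder_wants_mmvq) {
        use_dmmv = use_dmmv && !use_mmvq;   // MMVQ takes precedence here
    }
    if (use_dmmv) return "DMMV (reorder possible)";
    if (use_mmvq) return "MMVQ (reorder possible)";
    if (use_mmq)  return "MMQ";
    return "MUL_MAT_SYCL (reorder possible)";
}

int main() {
    mm_case q4_single{true, true, true, true, true, false, true};
    std::printf("%s\n", pick_path(q4_single).c_str());  // MMVQ wins over DMMV
    return 0;
}
```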
3017
3159
 
3018
3160
 
@@ -3251,48 +3393,39 @@ catch (sycl::exception const &exc) {
3251
3393
  }
3252
3394
 
3253
3395
  static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3254
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale);
3255
- }
3256
-
3257
- static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3258
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
3396
+ ggml_sycl_op_scale(ctx, dst);
3259
3397
  }
3260
3398
 
3261
3399
  static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3262
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
3263
- }
3264
-
3265
- static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3266
- GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
3267
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
3400
+ ggml_sycl_op_diag_mask_inf(ctx, dst);
3268
3401
  }
3269
3402
 
3270
3403
  static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3271
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d);
3404
+ ggml_sycl_op_pool2d(ctx, dst);
3272
3405
  }
3273
3406
 
3274
3407
  static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3275
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col);
3408
+ ggml_sycl_op_im2col(ctx, dst);
3276
3409
  }
3277
3410
 
3278
3411
  static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3279
3412
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3280
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum);
3413
+ ggml_sycl_op_sum(ctx, dst);
3281
3414
  }
3282
3415
 
3283
3416
  static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3284
3417
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3285
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows);
3418
+ ggml_sycl_op_sum_rows(ctx, dst);
3286
3419
  }
3287
3420
 
3288
3421
  static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3289
3422
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3290
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort);
3423
+ ggml_sycl_op_argsort(ctx, dst);
3291
3424
  }
3292
3425
 
3293
3426
  static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3294
3427
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
3295
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax);
3428
+ ggml_sycl_op_argmax(ctx, dst);
3296
3429
  }
3297
3430
 
3298
3431
 
@@ -3317,7 +3450,7 @@ catch (sycl::exception const &exc) {
3317
3450
  std::exit(1);
3318
3451
  }
3319
3452
 
3320
- static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
3453
+ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
3321
3454
  if (!g_sycl_loaded) return false;
3322
3455
 
3323
3456
  if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@@ -3394,6 +3527,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3394
3527
  case GGML_UNARY_OP_EXP:
3395
3528
  ggml_sycl_exp(ctx, dst);
3396
3529
  break;
3530
+ case GGML_UNARY_OP_SGN:
3531
+ ggml_sycl_sgn(ctx, dst);
3532
+ break;
3533
+ case GGML_UNARY_OP_ABS:
3534
+ ggml_sycl_abs(ctx, dst);
3535
+ break;
3536
+ case GGML_UNARY_OP_ELU:
3537
+ ggml_sycl_elu(ctx, dst);
3538
+ break;
3397
3539
  default:
3398
3540
  return false;
3399
3541
  }
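The three newly routed unary ops have straightforward element-wise definitions; the reference loops below are host-side stand-ins (assuming ELU with alpha = 1), while the backend dispatches to the corresponding SYCL kernels.

```cpp
// Reference-only sketch of the newly routed unary ops (SGN, ABS, ELU), written as
// plain host code; the backend runs SYCL kernels instead of these loops.
#include <cmath>
#include <cstdio>
#include <vector>

static float op_sgn(float x) { return (x > 0.0f) - (x < 0.0f); }
static float op_abs(float x) { return std::fabs(x); }
static float op_elu(float x) { return x > 0.0f ? x : std::expm1(x); } // alpha = 1

int main() {
    std::vector<float> src = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for (float x : src) {
        std::printf("x=% .2f sgn=% .2f abs=% .2f elu=% .5f\n",
                    x, op_sgn(x), op_abs(x), op_elu(x));
    }
    return 0;
}
```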
@@ -3510,6 +3652,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3510
3652
  }
3511
3653
 
3512
3654
  return true;
3655
+ } catch (sycl::exception & e) {
3656
+ std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
3657
+ std::exit(1);
3513
3658
  }
3514
3659
 
3515
3660
  GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
@@ -3641,71 +3786,8 @@ catch (sycl::exception const &exc) {
3641
3786
  std::exit(1);
3642
3787
  }
3643
3788
 
3644
- static void reorder_qw(char *data_device, const int ncols, const int nrows,
3645
- size_t size, size_t offset, dpct::queue_ptr stream) {
3646
- auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
3647
- SYCL_CHECK(
3648
- CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
3649
- .wait()));
3650
- GGML_ASSERT((size % sizeof(block_q4_0) == 0));
3651
- GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
3652
- int offset_blks = offset / sizeof(block_q4_0);
3653
- auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
3654
- auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
3655
-
3656
- stream->parallel_for(
3657
- size / sizeof(block_q4_0),
3658
- [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
3659
- const block_q4_0* x = (const block_q4_0*)tmp_buf;
3660
- const int ib = i;
3661
-
3662
- for (int j = 0; j < QK4_0/2; j ++)
3663
- {
3664
- *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
3665
- }
3666
- *(d_ptr + ib) = x[ib].d;
3667
- });
3668
-
3669
- sycl::free(tmp_buf, *stream);
3670
- }
3671
-
3672
- static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
3673
- char*data_device = (char*)src0->data;
3674
- size_t ncols = src0->ne[0];
3675
- size_t nrows = src0->ne[1];
3676
- size_t size = ggml_nbytes(src0);
3677
-
3678
- reorder_qw(data_device, ncols, nrows, size, 0, stream);
3679
- }
3680
-
3681
- static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
3682
- ggml_tensor *src0 = dst->src[0];
3683
- ggml_tensor *src1 = dst->src[1];
3684
-
3685
- if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
3686
- src1->ne[2]==1 && src1->ne[3]==1) {
3687
- reorder_qw(src0, stream);
3688
- ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
3689
- GGML_ASSERT(extra);
3690
- extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
3691
- }
3692
- }
3693
-
3694
- static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
3695
- dpct::queue_ptr stream = ctx->stream();
3696
- if (ctx->optimized_graph) {
3697
- return;
3698
- }
3699
- ctx->optimized_graph = true;
3700
-
3701
- for (int i = 0; i < cgraph->n_nodes; i++) {
3702
- if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
3703
- }
3704
- }
3705
-
3706
3789
  static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
3707
3790
  ggml_sycl_set_main_device(sycl_ctx->device);
3708
- if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
3709
3791
 
3710
3792
  for (int i = 0; i < cgraph->n_nodes; i++) {
3711
3793
  ggml_tensor * node = cgraph->nodes[i];
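With the old reorder_qw/opt_for_reorder/optimize_graph_once trio removed here, the weight reorder no longer runs as a separate whole-graph pass: it now happens lazily inside the mul-mat dispatch, guarded by the per-tensor extra->optimized_feature.reorder flag so it only runs once. A minimal sketch of that guard pattern, with illustrative names:

```cpp
// Minimal sketch of the lazy, one-shot reorder now done inside the mul-mat dispatch
// instead of a whole-graph pass: a per-tensor flag records that the weight data has
// already been rewritten so later calls skip the work. Names are illustrative.
#include <cstdio>
#include <vector>

struct tensor_extra { bool reordered = false; };

struct weight_tensor {
    std::vector<float> data;
    tensor_extra       extra;
};

static void reorder_weights(weight_tensor & w) {
    // stand-in for reorder_qw(): any in-place layout change would go here
    std::printf("reordering %zu values\n", w.data.size());
}

static void opt_for_reorder(weight_tensor & w, bool device_wants_reorder) {
    if (!device_wants_reorder || w.extra.reordered) {
        return;                       // disabled, or already done on a previous call
    }
    reorder_weights(w);
    w.extra.reordered = true;         // later mat-muls reuse the reordered layout
}

int main() {
    weight_tensor w{std::vector<float>(8, 1.0f), {}};
    opt_for_reorder(w, true);         // first call performs the reorder
    opt_for_reorder(w, true);         // second call is a no-op
    return 0;
}
```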
@@ -3733,19 +3815,23 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
3733
3815
 
3734
3816
  #ifdef GGML_SYCL_GRAPH
3735
3817
  if (!g_ggml_sycl_disable_graph) {
3736
- if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
3818
+ const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
3819
+ if (!graph_support) {
3737
3820
  GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
3738
3821
  ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
3739
3822
  return GGML_STATUS_SUCCESS;
3740
3823
  }
3741
3824
 
3742
- sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
3825
+ sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
3826
+
3743
3827
  model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
3744
3828
  ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
3745
3829
  model_sycl_graph.end_recording();
3746
3830
 
3747
- if (!sycl_ctx->exec_graph) {
3748
- auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
3831
+ const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
3832
+ if (!sycl_ctx->exec_graph || !graph_update_support) {
3833
+ auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
3834
+ model_sycl_graph.finalize();
3749
3835
  sycl_ctx->exec_graph = std::make_unique<
3750
3836
  sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
3751
3837
  } else {
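The updated graph path distinguishes two device capabilities: ext_oneapi_limited_graph gates whether recording is attempted at all, while ext_oneapi_graph gates whether a cached executable graph can be updated in place instead of being re-finalized. The plain C++ model below captures only that decision logic (no SYCL calls); the enum and struct names are illustrative.

```cpp
// Plain C++ model (no SYCL) of the graph-reuse policy: recording needs the "limited
// graph" capability, and the cached executable graph is updated in place only when
// the full graph-update capability is present; otherwise it is re-finalized.
#include <cstdio>

struct device_caps {
    bool limited_graph; // can record/replay at all      (ext_oneapi_limited_graph)
    bool graph_update;  // can update a finalized graph  (ext_oneapi_graph)
};

enum class action { EAGER_SUBMIT, FINALIZE_NEW, UPDATE_EXISTING };

static action choose(const device_caps & d, bool have_cached_exec_graph) {
    if (!d.limited_graph) return action::EAGER_SUBMIT;                 // no graph support
    if (!have_cached_exec_graph || !d.graph_update) return action::FINALIZE_NEW;
    return action::UPDATE_EXISTING;                                    // cheap replay path
}

int main() {
    const char * names[] = {"eager submit", "finalize new graph", "update cached graph"};
    device_caps full{true, true}, limited{true, false}, none{false, false};
    std::printf("first run  (full caps):  %s\n", names[(int) choose(full, false)]);
    std::printf("later runs (full caps):  %s\n", names[(int) choose(full, true)]);
    std::printf("later runs (limited):    %s\n", names[(int) choose(limited, true)]);
    std::printf("any run    (no support): %s\n", names[(int) choose(none, false)]);
    return 0;
}
```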
@@ -3933,7 +4019,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
3933
4019
  case GGML_UNARY_OP_GELU_QUICK:
3934
4020
  case GGML_UNARY_OP_TANH:
3935
4021
  case GGML_UNARY_OP_EXP:
3936
- return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
4022
+ case GGML_UNARY_OP_SGN:
4023
+ case GGML_UNARY_OP_ABS:
4024
+ case GGML_UNARY_OP_ELU:
4025
+ #if defined (GGML_SYCL_F16)
4026
+ return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
4027
+ #else
4028
+ return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4029
+ #endif
3937
4030
  default:
3938
4031
  return false;
3939
4032
  }
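The relaxed unary-op check above requires a contiguous source and a destination of the same type; without GGML_SYCL_F16 only F32 passes, while with it any matching contiguous pair is accepted. A compact sketch of that predicate, with illustrative names:

```cpp
// Sketch of the relaxed unary-op support check (names illustrative): the source must
// be contiguous and the destination must keep the source type; the looser branch is
// only taken when the backend is built with GGML_SYCL_F16.
#include <cstdio>

enum class dtype { F32, F16, OTHER };

static bool supports_unary(dtype dst, dtype src, bool contiguous, bool built_with_f16) {
    if (!contiguous || dst != src) return false;
    if (built_with_f16) return true;              // matching contiguous types are enough
    return src == dtype::F32;                     // otherwise only F32 is accepted
}

int main() {
    std::printf("f32->f32: %d\n", supports_unary(dtype::F32, dtype::F32, true, false)); // 1
    std::printf("f16->f16 (no F16 build): %d\n",
                supports_unary(dtype::F16, dtype::F16, true, false));                   // 0
    std::printf("f16->f16 (F16 build): %d\n",
                supports_unary(dtype::F16, dtype::F16, true, true));                    // 1
    return 0;
}
```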
@@ -4045,7 +4138,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4045
4138
  case GGML_OP_ARGMAX:
4046
4139
  case GGML_OP_NONE:
4047
4140
  case GGML_OP_RESHAPE:
4048
- case GGML_OP_REPEAT:
4049
4141
  case GGML_OP_VIEW:
4050
4142
  case GGML_OP_PERMUTE:
4051
4143
  case GGML_OP_TRANSPOSE:
@@ -4055,13 +4147,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4055
4147
  case GGML_OP_SUB:
4056
4148
  case GGML_OP_MUL:
4057
4149
  case GGML_OP_DIV:
4150
+ case GGML_OP_REPEAT:
4151
+ return true;
4058
4152
  case GGML_OP_SQR:
4059
4153
  case GGML_OP_SQRT:
4060
4154
  case GGML_OP_SIN:
4061
4155
  case GGML_OP_COS:
4062
4156
  case GGML_OP_CLAMP:
4063
4157
  case GGML_OP_LOG:
4064
- return (op->src[0]->type == GGML_TYPE_F32);
4158
+ #if defined (GGML_SYCL_F16)
4159
+ return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
4160
+ #else
4161
+ return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
4162
+ #endif
4065
4163
  case GGML_OP_NORM:
4066
4164
  case GGML_OP_RMS_NORM:
4067
4165
  case GGML_OP_L2_NORM:
@@ -4077,23 +4175,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4077
4175
  case GGML_OP_ROPE:
4078
4176
  {
4079
4177
  const int mode = ((const int32_t *) op->op_params)[2];
4080
- if (mode & GGML_ROPE_TYPE_MROPE) {
4081
- return false;
4082
- }
4083
- if (mode & GGML_ROPE_TYPE_VISION) {
4178
+ // mode is not used as a bitmask in practice, the various rope type modes are independent implementations
4179
+ if (mode == GGML_ROPE_TYPE_MROPE) {
4084
4180
  return false;
4085
4181
  }
4086
- return ggml_is_contiguous(op->src[0]);
4182
+ return true;
4087
4183
  }
4088
4184
  case GGML_OP_IM2COL:
4089
- // TODO: add support for the new F32 operations
4090
- return op->src[0]->type == GGML_TYPE_F16;
4185
+ return true;
4186
+ case GGML_OP_UPSCALE:
4187
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
4091
4188
  case GGML_OP_POOL_2D:
4092
4189
  case GGML_OP_SUM:
4093
4190
  case GGML_OP_SUM_ROWS:
4094
4191
  case GGML_OP_ARGSORT:
4095
4192
  case GGML_OP_ACC:
4096
- case GGML_OP_UPSCALE:
4097
4193
  case GGML_OP_PAD:
4098
4194
  case GGML_OP_LEAKY_RELU:
4099
4195
  case GGML_OP_TIMESTEP_EMBEDDING: