@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt
@@ -0,0 +1,81 @@
+ if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
+     message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
+ endif()
+
+ check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
+
+ if (DEFINED ENV{ONEAPI_ROOT})
+     message(STATUS "Using oneAPI Release SYCL compiler (icpx).")
+ elseif(SUPPORTS_SYCL)
+     message(WARNING "Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}.
+     If you expected the oneAPI Release compiler, please install oneAPI & source it, like:
+     source /opt/intel/oneapi/setvars.sh")
+ else()
+     message(FATAL_ERROR, "C++ compiler lacks SYCL support.")
+ endif()
+ message(STATUS "SYCL found")
+ #todo: AOT
+
+ add_library(ggml-sycl
+             ggml-sycl.cpp
+             ../../include/ggml-sycl.h)
+
+ target_link_libraries(ggml-sycl PRIVATE ggml-base)
+ target_include_directories(ggml-sycl PRIVATE . ..)
+
+ if (GGML_SYCL_F16)
+     if (GGML_SYCL_TARGET STREQUAL "AMD")
+         message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
+     endif()
+     add_compile_definitions(GGML_SYCL_F16)
+ endif()
+
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
+
+ if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+     add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+     # INFO: Allowed Sub_group_sizes are not consistent through all
+     # hip targets. For example, 64 is used for certain models, but the backend
+     # does not support it.
+     # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
+     add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+ else()
+     add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
+ endif()
+
+ file(GLOB GGML_HEADERS_SYCL "*.hpp")
+ file(GLOB GGML_SOURCES_SYCL "*.cpp")
+ target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
+
+ find_package(DNNL)
+ message("-- DNNL found:" ${DNNL_FOUND})
+
+ if (GGML_SYCL_TARGET STREQUAL "INTEL")
+     add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+ else()
+     add_compile_definitions(GGML_SYCL_DNNL=0)
+ endif()
+
+ if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+     target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+ endif()
+
+ if (WIN32)
+     find_package(IntelSYCL REQUIRED)
+     find_package(MKL REQUIRED)
+     target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+ else()
+     if (GGML_SYCL_TARGET STREQUAL "INTEL")
+         target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+     elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
+         target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+     elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+         if (GGML_SYCL_HIP_TARGET STREQUAL "")
+             message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_HIP_TARGET has not been set.")
+         endif()
+         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${GGML_SYCL_HIP_TARGET}")
+         target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+     endif()
+ endif()
package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp
@@ -15,6 +15,7 @@
 
  #include "concat.hpp"
  #include "common.hpp"
+ #include "conv.hpp"
  #include "convert.hpp"
  #include "dequantize.hpp"
  #include "dmmv.hpp"
@@ -23,5 +24,10 @@
  #include "rope.hpp"
  #include "norm.hpp"
  #include "softmax.hpp"
+ #include "tsembd.hpp"
+ #include "im2col.hpp"
+ #include "wkv6.hpp"
+ #include "outprod.hpp"
+ #include "element_wise.hpp"
 
  #endif // GGML_SYCL_BACKEND_HPP
package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp
@@ -51,3 +51,54 @@ void ggml_sycl_host_free(void* ptr) try {
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
  }
+
+ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
+     const int64_t max_range = std::numeric_limits<int>::max();
+     int64_t sycl_down_blk_size = block_size;
+     int64_t global_range = accumulate_block_num * sycl_down_blk_size;
+     while(global_range > max_range) {
+         sycl_down_blk_size /= 2;
+         global_range = accumulate_block_num * sycl_down_blk_size;
+     }
+     return sycl_down_blk_size;
+ }
+
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                           const ggml_tensor *src1, ggml_tensor *dst,
+                           const ggml_sycl_op_flatten_t op) try {
+     const int64_t nrows0 = ggml_nrows(src0);
+
+     const bool use_src1 = src1 != nullptr;
+     const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+     GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+     ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+     // dd = data device
+     float * src0_ddf = (float *) src0->data;
+     float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+     float * dst_ddf = (float *) dst->data;
+
+     ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+     ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+     ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
+
+     ggml_sycl_set_device(ctx.device);
+     queue_ptr main_stream = ctx.stream();
+     // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+     //     ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+     // do the computation
+     op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+     // print_ggml_tensor("tensor", dst);
+ }
+ catch (sycl::exception const &exc) {
+
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
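
Editor's note: the new downsample_sycl_global_range helper above simply halves the work-group size until the total global range fits in a signed int. A minimal standalone restatement in C++, with made-up sample numbers purely for illustration (not taken from the diff):

    #include <cstdint>
    #include <cstdio>
    #include <limits>

    // Same halving logic as downsample_sycl_global_range, written as a free function.
    static int64_t downsample_global_range(int64_t accumulate_block_num, int64_t block_size) {
        const int64_t max_range = std::numeric_limits<int>::max();
        int64_t blk = block_size;
        while (accumulate_block_num * blk > max_range) {
            blk /= 2;
        }
        return blk;
    }

    int main() {
        // 16,777,216 blocks at 256 work-items each is 2^32 total, which overflows int,
        // so the helper halves the block size twice and settles on 64 (2^30 total).
        std::printf("%lld\n", (long long) downsample_global_range(16777216, 256));
        return 0;
    }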
package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp
@@ -19,6 +19,10 @@
  #include "dpct/helper.hpp"
  #include "ggml-sycl.h"
  #include "presets.hpp"
+ #if GGML_SYCL_DNNL
+ #include "dnnl.hpp"
+ #include "dnnl_sycl.hpp"
+ #endif
 
  #define GGML_COMMON_DECL_SYCL
  #define GGML_COMMON_IMPL_SYCL
@@ -276,6 +280,52 @@ struct ggml_backend_sycl_context {
      return stream(device, 0);
  }
 
+ #if GGML_SYCL_DNNL
+     dnnl::engine make_engine(sycl::queue* q) {
+         // Get the device associated with the queue
+         sycl::device dev = q->get_device();
+         // Get the context associated with the queue
+         sycl::context ctx = q->get_context();
+         const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+         return eng;
+     }
+
+     std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+     std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+     dnnl::stream stream_dnnl(int device, int _stream) {
+         auto q = stream(device, _stream);
+         return stream_dnnl(q);
+     }
+     dnnl::engine engine_dnnl(sycl::queue* qptr) {
+         auto it = engine_map.find(qptr);
+         if (it == engine_map.end()) {
+             auto eng = make_engine(qptr);
+             engine_map[qptr] = eng;
+             return eng;
+         }
+         else
+         {
+             return it->second;
+         }
+     }
+     dnnl::stream stream_dnnl(sycl::queue* qptr) {
+         auto it = stream_map.find(qptr);
+         if (it == stream_map.end()) {
+             auto eng = engine_dnnl(qptr);
+             auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+             stream_map[qptr] = stream;
+             return stream;
+         }
+         else
+         {
+             return it->second;
+         }
+     }
+     dnnl::stream stream_dnnl() {
+         return stream_dnnl(device, 0);
+     }
+ #endif
+
  // pool
  std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
 
@@ -352,4 +402,264 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
      return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
  }
 
+ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
+
+ typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                        const ggml_tensor *src1,
+                                        ggml_tensor *dst, const float *src0_dd,
+                                        const float *src1_dd, float *dst_dd,
+                                        const queue_ptr &main_stream);
+
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+                         int ne0, int ne1, int ne2, int ne3,
+                         int ne10, int ne11, int ne12, int ne13,
+                         /*int s0, */ int s1, int s2, int s3,
+                         /*int s10,*/ int s11, int s12, int s13,
+                         const sycl::nd_item<3> &item_ct1) {
+     const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                     item_ct1.get_local_id(2);
+     const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                     item_ct1.get_local_id(1));
+     const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne3;
+     const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne3;
+
+     if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+         return;
+     }
+
+     const int i11 = i1 % ne11;
+     const int i12 = i2 % ne12;
+     const int i13 = i3 % ne13;
+
+     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+     const size_t i_dst = i_src0;
+
+     const src0_t * src0_row = src0 + i_src0;
+     const src1_t * src1_row = src1 + i_src1;
+     dst_t * dst_row = dst + i_dst;
+
+     for (int i0 = i0s; i0 < ne0;
+          i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+         const int i10 = i0 % ne10;
+         dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+     }
+ }
+
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+                                 int ne0, int ne1, int ne2, int ne3,
+                                 int ne10, int ne11, int ne12, int ne13,
+                                 /*int s0, */ int s1, int s2, int s3,
+                                 /*int s10,*/ int s11, int s12, int s13,
+                                 const sycl::nd_item<3> &item_ct1) {
+
+     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                   item_ct1.get_local_id(2);
+
+     const int i3 = i/(ne2*ne1*ne0);
+     const int i2 = (i/(ne1*ne0)) % ne2;
+     const int i1 = (i/ne0) % ne1;
+     const int i0 = i % ne0;
+
+     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+         return;
+     }
+
+     const int i11 = i1 % ne11;
+     const int i12 = i2 % ne12;
+     const int i13 = i3 % ne13;
+
+     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+     const size_t i_dst = i_src0;
+
+     const src0_t * src0_row = src0 + i_src0;
+     const src1_t * src1_row = src1 + i_src1;
+     dst_t * dst_row = dst + i_dst;
+
+     const int i10 = i0 % ne10;
+     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+ }
+
+
+ template<float (*bin_op)(const float, const float)>
+ struct bin_bcast_sycl {
+     template <typename src0_t, typename src1_t, typename dst_t>
+     void operator()(ggml_backend_sycl_context & ctx,
+                     const struct ggml_tensor *src0,
+                     const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                     const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                     queue_ptr stream) {
+
+         GGML_TENSOR_BINARY_OP_LOCALS
+
+         int nr0 = ne10/ne0;
+         int nr1 = ne11/ne1;
+         int nr2 = ne12/ne2;
+         int nr3 = ne13/ne3;
+
+         int nr[4] = { nr0, nr1, nr2, nr3 };
+
+         // collapse dimensions until first broadcast dimension
+         int64_t cne0[] = {ne0, ne1, ne2, ne3};
+         int64_t cne1[] = {ne10, ne11, ne12, ne13};
+         size_t cnb0[] = {nb0, nb1, nb2, nb3};
+         size_t cnb1[] = {nb10, nb11, nb12, nb13};
+         auto collapse = [](int64_t cne[]) {
+             cne[0] *= cne[1];
+             cne[1] = cne[2];
+             cne[2] = cne[3];
+             cne[3] = 1;
+         };
+
+         auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+             cnb[1] *= cne[1];
+             cnb[2] *= cne[2];
+             cnb[3] *= cne[3];
+         };
+
+         for (int i = 0; i < 4; i++) {
+             if (nr[i] != 1) {
+                 break;
+             }
+             if (i > 0) {
+                 collapse_nb(cnb0, cne0);
+                 collapse_nb(cnb1, cne1);
+                 collapse(cne0);
+                 collapse(cne1);
+             }
+         }
+         {
+             int64_t ne0 = cne0[0];
+             int64_t ne1 = cne0[1];
+             int64_t ne2 = cne0[2];
+             int64_t ne3 = cne0[3];
+
+             int64_t ne10 = cne1[0];
+             int64_t ne11 = cne1[1];
+             int64_t ne12 = cne1[2];
+             int64_t ne13 = cne1[3];
+
+             size_t nb0 = cnb0[0];
+             size_t nb1 = cnb0[1];
+             size_t nb2 = cnb0[2];
+             size_t nb3 = cnb0[3];
+
+             size_t nb10 = cnb1[0];
+             size_t nb11 = cnb1[1];
+             size_t nb12 = cnb1[2];
+             size_t nb13 = cnb1[3];
+
+             size_t s0 = nb0 / sizeof(dst_t);
+             size_t s1 = nb1 / sizeof(dst_t);
+             size_t s2 = nb2 / sizeof(dst_t);
+             size_t s3 = nb3 / sizeof(dst_t);
+
+             size_t s10 = nb10 / sizeof(src1_t);
+             size_t s11 = nb11 / sizeof(src1_t);
+             size_t s12 = nb12 / sizeof(src1_t);
+             size_t s13 = nb13 / sizeof(src1_t);
+
+             GGML_ASSERT(s0 == 1);
+             GGML_ASSERT(s10 == 1);
+
+             const int block_size = 128;
+
+             int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+             sycl::range<3> block_dims(1, 1, 1);
+             block_dims[2] = std::min<unsigned int>(hne0, block_size);
+             block_dims[1] = std::min<unsigned int>(
+                 ne1, block_size / (unsigned int)block_dims[2]);
+             block_dims[0] = std::min(
+                 std::min<unsigned int>(
+                     ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                    (unsigned int)block_dims[1]),
+                 64U);
+
+             sycl::range<3> block_nums(
+                 (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                 (ne1 + block_dims[1] - 1) / block_dims[1],
+                 (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+             if (block_nums[0] > 65535) {
+                 // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                 {
+                     dpct::has_capability_or_fail(stream->get_device(),
+                                                  {sycl::aspect::fp16});
+
+                     stream->parallel_for(
+                         sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                               sycl::range<3>(1, 1, block_size),
+                                           sycl::range<3>(1, 1, block_size)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             k_bin_bcast_unravel<bin_op>(
+                                 src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                 ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                 s13, item_ct1);
+                         });
+                 }
+             } else {
+                 /*
+                 DPCT1049:16: The work-group size passed to the SYCL kernel may
+                 exceed the limit. To get the device limit, query
+                 info::device::max_work_group_size. Adjust the work-group size if
+                 needed.
+                 */
+                 dpct::has_capability_or_fail(stream->get_device(),
+                                              {sycl::aspect::fp16});
+
+                 stream->parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                             ne2, ne3, ne10, ne11, ne12, ne13,
+                                             s1, s2, s3, s11, s12, s13,
+                                             item_ct1);
+                     });
+             }
+         }
+     }
+ };
+
+ template <class op>
+ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst,
+                                    const float *src0_dd, const float *src1_dd,
+                                    float *dst_dd,
+                                    const queue_ptr &main_stream) {
+
+     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+         op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+         op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+              (sycl::half *)dst_dd, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+         op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+              main_stream);
+     } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+         op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+              main_stream);
+     } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+         op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+              main_stream);
+     } else {
+         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                 ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+         GGML_ABORT("fatal error");
+     }
+ }
+
+
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                           const ggml_tensor *src1, ggml_tensor *dst,
+                           const ggml_sycl_op_flatten_t op);
+
  #endif // GGML_SYCL_COMMON_HPP
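
Editor's note: the "collapse dimensions until first broadcast dimension" step in bin_bcast_sycl above merges leading dimensions where src1 already matches dst into one larger dimension, and stops at the first dimension that is actually broadcast. A small host-only C++ sketch of that step (shapes are made up, and the cnb0/cnb1 stride bookkeeping is omitted for brevity):

    #include <cstdint>
    #include <cstdio>

    // Same merge step as the collapse() lambda in bin_bcast_sycl.
    static void collapse(int64_t cne[4]) {
        cne[0] *= cne[1];
        cne[1] = cne[2];
        cne[2] = cne[3];
        cne[3] = 1;
    }

    int main() {
        int64_t dst_ne [4] = {4096, 32, 8, 1};   // dst / src0 shape
        int64_t src1_ne[4] = {4096, 32, 1, 1};   // src1 is broadcast along dimension 2

        int nr[4];
        for (int i = 0; i < 4; i++) {
            nr[i] = (int)(src1_ne[i] / dst_ne[i]);   // 1 where the shapes match, 0 where src1 is broadcast
        }
        for (int i = 0; i < 4; i++) {
            if (nr[i] != 1) break;                   // stop at the first broadcast dimension
            if (i > 0) {
                collapse(dst_ne);
                collapse(src1_ne);
            }
        }
        // Dimensions 0 and 1 are merged: dst becomes {131072, 8, 1, 1}, src1 becomes {131072, 1, 1, 1}.
        std::printf("dst  = {%lld, %lld, %lld, %lld}\n",
                    (long long) dst_ne[0], (long long) dst_ne[1], (long long) dst_ne[2], (long long) dst_ne[3]);
        std::printf("src1 = {%lld, %lld, %lld, %lld}\n",
                    (long long) src1_ne[0], (long long) src1_ne[1], (long long) src1_ne[2], (long long) src1_ne[3]);
        return 0;
    }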
package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
                  concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
              });
          break;
+     // dim >=2 will be dispatched to the default path
      default:
          stream->parallel_for(
              sycl::nd_range<3>(gridDim *
package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp
@@ -0,0 +1,99 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #include "conv.hpp"
+
+ static void conv_transpose_1d_kernel(
+         const int s0, const int output_size,
+         const int src0_ne0, const int src0_ne1, const int src0_ne2,
+         const int src1_ne0, const int dst_ne0,
+         const float * src0, const float * src1, float * dst,
+         const sycl::nd_item<3> &item_ct1) {
+     int global_index = item_ct1.get_local_id(2) +
+                        item_ct1.get_group(2) * item_ct1.get_local_range(2);
+     if (global_index >= output_size) {
+         return;
+     }
+
+     int out_index = global_index / dst_ne0;
+
+     float accumulator = 0;
+
+     for (int c = 0; c < src0_ne2; c++) {
+         int idx = global_index % dst_ne0;
+
+         int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+         int input_offset = src1_ne0 * c;
+
+         for (int i = 0; i < src1_ne0; i++) {
+             if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+                 continue;
+             }
+             int weight_idx = idx - i*s0;
+
+             float kernel_weight = src0[kernel_offset + weight_idx];
+             float input_value = src1[input_offset+i];
+
+             accumulator += kernel_weight * input_value;
+         }
+     }
+     dst[global_index] = accumulator;
+ }
+
+ static void conv_transpose_1d_f32_f32_sycl(
+         const int s0, const int output_size,
+         const int src0_ne0, const int src0_ne1, const int src0_ne2,
+         const int src1_ne0, const int dst_ne0,
+         const float *src0, const float *src1, float *dst,
+         const queue_ptr& stream) {
+
+     const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
+     const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
+     const sycl::range<3> block_nums(1, 1, num_blocks);
+     stream->parallel_for(
+         sycl::nd_range<3>(
+             block_nums * block_dims, block_dims),
+         [=](sycl::nd_item<3> item_ct1) {
+             conv_transpose_1d_kernel(
+                 s0, output_size,
+                 src0_ne0, src0_ne1, src0_ne2,
+                 src1_ne0, dst_ne0,
+                 src0, src1, dst, item_ct1);
+         });
+ }
+
+ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                     const ggml_tensor *src1, ggml_tensor *dst) {
+     const float * src0_d = (const float *)src0->data;
+     const float * src1_d = (const float *)src1->data;
+
+     float * dst_d = (float *)dst->data;
+     dpct::queue_ptr stream = ctx.stream();
+
+     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+     GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+     GGML_ASSERT(ggml_is_contiguous(src0));
+     GGML_ASSERT(ggml_is_contiguous(src1));
+
+     const int32_t * opts = (const int32_t *)dst->op_params;
+
+     const int s0 = opts[0];
+
+     const int64_t output_size = ggml_nelements(dst);
+
+     conv_transpose_1d_f32_f32_sycl(s0, output_size,
+                                    src0->ne[0], src0->ne[1], src0->ne[2],
+                                    src1->ne[0], dst->ne[0],
+                                    src0_d, src1_d, dst_d, stream);
+ }
+
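
Editor's note: to make the indexing in conv_transpose_1d_kernel above easier to follow, here is a plain-C++ restatement of the same computation written as full loops over the output, for a single batch. The layout assumptions mirror the kernel arguments (src0 as [src0_ne2][src0_ne1][src0_ne0] filter data, src1 as [src0_ne2][src1_ne0] input); this is an illustrative sketch only, not part of the package:

    #include <vector>

    // CPU reference for the transposed 1D convolution performed by the SYCL kernel above.
    std::vector<float> conv_transpose_1d_ref(
            int s0,
            int src0_ne0, int src0_ne1, int src0_ne2,   // filter width, output channels, input channels
            int src1_ne0, int dst_ne0,                  // input length, output length
            const std::vector<float> & src0,            // filters
            const std::vector<float> & src1) {          // input
        std::vector<float> dst((size_t) src0_ne1 * dst_ne0, 0.0f);
        for (int out = 0; out < src0_ne1; ++out) {          // out_index in the kernel
            for (int idx = 0; idx < dst_ne0; ++idx) {       // position within the output row
                float acc = 0.0f;
                for (int c = 0; c < src0_ne2; ++c) {
                    for (int i = 0; i < src1_ne0; ++i) {
                        if (idx < i*s0 || idx >= i*s0 + src0_ne0) {
                            continue;                       // this input sample does not touch idx
                        }
                        const int w = idx - i*s0;           // offset inside the filter window
                        acc += src0[(size_t) src0_ne0 * src0_ne1 * c + (size_t) out * src0_ne0 + w]
                             * src1[(size_t) src1_ne0 * c + i];
                    }
                }
                dst[(size_t) out * dst_ne0 + idx] = acc;
            }
        }
        return dst;
    }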
package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp
@@ -0,0 +1,21 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #ifndef GGML_SYCL_CONV_HPP
+ #define GGML_SYCL_CONV_HPP
+
+ #include "common.hpp"
+
+ void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                     const ggml_tensor *src1, ggml_tensor *dst);
+
+ #endif // GGML_SYCL_CONV_HPP