@fugood/llama.node 0.2.2 → 0.3.0

This diff shows the content of publicly available package versions as released to one of the supported registries, and is provided for informational purposes only.
Files changed (320)
  1. package/CMakeLists.txt +5 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +8 -1
  17. package/package.json +1 -1
  18. package/patches/llama.patch +12 -12
  19. package/src/DetokenizeWorker.cpp +1 -1
  20. package/src/LlamaContext.cpp +33 -1
  21. package/src/LlamaContext.h +1 -0
  22. package/src/LoadSessionWorker.cpp +1 -0
  23. package/src/llama.cpp/.github/workflows/bench.yml +310 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +1315 -0
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +23 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +116 -0
  27. package/src/llama.cpp/.github/workflows/editorconfig.yml +27 -0
  28. package/src/llama.cpp/.github/workflows/gguf-publish.yml +44 -0
  29. package/src/llama.cpp/.github/workflows/labeler.yml +17 -0
  30. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +65 -0
  31. package/src/llama.cpp/.github/workflows/nix-ci.yml +72 -0
  32. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +22 -0
  33. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +36 -0
  34. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +35 -0
  35. package/src/llama.cpp/.github/workflows/python-lint.yml +23 -0
  36. package/src/llama.cpp/.github/workflows/python-type-check.yml +38 -0
  37. package/src/llama.cpp/.github/workflows/server.yml +183 -0
  38. package/src/llama.cpp/CMakeLists.txt +91 -1245
  39. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +1 -1
  40. package/src/llama.cpp/cmake/build-info.cmake +58 -0
  41. package/src/llama.cpp/cmake/git-vars.cmake +22 -0
  42. package/src/llama.cpp/common/CMakeLists.txt +4 -3
  43. package/src/llama.cpp/common/build-info.cpp.in +4 -0
  44. package/src/llama.cpp/common/common.cpp +1116 -877
  45. package/src/llama.cpp/common/common.h +191 -77
  46. package/src/llama.cpp/common/grammar-parser.cpp +118 -31
  47. package/src/llama.cpp/common/json-schema-to-grammar.cpp +346 -65
  48. package/src/llama.cpp/common/log.h +1 -1
  49. package/src/llama.cpp/common/ngram-cache.h +10 -3
  50. package/src/llama.cpp/common/sampling.cpp +19 -10
  51. package/src/llama.cpp/docs/build.md +353 -0
  52. package/src/llama.cpp/examples/CMakeLists.txt +22 -22
  53. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +6 -6
  55. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/batched/batched.cpp +52 -55
  57. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +20 -72
  59. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/chat-13B.bat +57 -0
  61. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/{finetune → cvector-generator}/CMakeLists.txt +2 -2
  63. package/src/llama.cpp/examples/cvector-generator/completions.txt +582 -0
  64. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +503 -0
  65. package/src/llama.cpp/examples/cvector-generator/mean.hpp +48 -0
  66. package/src/llama.cpp/examples/cvector-generator/negative.txt +4 -0
  67. package/src/llama.cpp/examples/cvector-generator/pca.hpp +325 -0
  68. package/src/llama.cpp/examples/cvector-generator/positive.txt +4 -0
  69. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +35 -0
  70. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  71. package/src/llama.cpp/examples/embedding/embedding.cpp +94 -46
  72. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +2 -2
  73. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +4 -6
  74. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/export-lora/export-lora.cpp +344 -386
  76. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +2 -2
  77. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +30 -25
  78. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/gguf/gguf.cpp +5 -0
  80. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +15 -0
  81. package/src/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h +46 -0
  82. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c +295 -0
  83. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h +52 -0
  84. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c +221 -0
  85. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h +24 -0
  86. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c +42 -0
  87. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h +7093 -0
  88. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +693 -0
  89. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  90. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +3 -3
  91. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  92. package/src/llama.cpp/examples/gritlm/gritlm.cpp +6 -2
  93. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  94. package/src/llama.cpp/examples/imatrix/imatrix.cpp +137 -176
  95. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/infill/infill.cpp +38 -153
  97. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +175 -94
  98. package/src/llama.cpp/examples/llama.android/app/build.gradle.kts +65 -0
  99. package/src/llama.cpp/examples/llama.android/build.gradle.kts +6 -0
  100. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +68 -0
  101. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +11 -7
  102. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +2 -2
  103. package/src/llama.cpp/examples/llama.android/settings.gradle.kts +18 -0
  104. package/src/llama.cpp/examples/llava/CMakeLists.txt +6 -5
  105. package/src/llama.cpp/examples/llava/android/build_64.sh +8 -0
  106. package/src/llama.cpp/examples/llava/clip.cpp +23 -14
  107. package/src/llama.cpp/examples/llava/llava-cli.cpp +8 -6
  108. package/src/llama.cpp/examples/llava/requirements.txt +3 -2
  109. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  110. package/src/llama.cpp/examples/lookahead/lookahead.cpp +2 -1
  111. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  112. package/src/llama.cpp/examples/lookup/lookup-create.cpp +2 -0
  113. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  114. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -2
  115. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  116. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  117. package/src/llama.cpp/examples/main/main.cpp +98 -75
  118. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +4 -5
  119. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  120. package/src/llama.cpp/examples/parallel/parallel.cpp +2 -1
  121. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  122. package/src/llama.cpp/examples/passkey/passkey.cpp +23 -43
  123. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  124. package/src/llama.cpp/examples/perplexity/perplexity.cpp +13 -10
  125. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  126. package/src/llama.cpp/examples/quantize/quantize.cpp +37 -34
  127. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +1 -1
  129. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  130. package/src/llama.cpp/examples/retrieval/retrieval.cpp +26 -77
  131. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +14 -7
  133. package/src/llama.cpp/examples/server/CMakeLists.txt +26 -2
  134. package/src/llama.cpp/examples/server/server.cpp +274 -671
  135. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  136. package/src/llama.cpp/examples/server/utils.hpp +28 -29
  137. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  138. package/src/llama.cpp/examples/simple/simple.cpp +21 -29
  139. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  140. package/src/llama.cpp/examples/speculative/speculative.cpp +2 -1
  141. package/src/llama.cpp/examples/sycl/CMakeLists.txt +1 -1
  142. package/src/llama.cpp/examples/sycl/build.sh +23 -0
  143. package/src/llama.cpp/examples/sycl/run-llama2.sh +36 -0
  144. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +33 -0
  145. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +9 -0
  146. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  147. package/src/llama.cpp/examples/tokenize/tokenize.cpp +16 -2
  148. package/src/llama.cpp/ggml/CMakeLists.txt +253 -0
  149. package/src/llama.cpp/{cmake → ggml/cmake}/FindSIMD.cmake +6 -6
  150. package/src/llama.cpp/{ggml-backend.h → ggml/include/ggml-backend.h} +22 -17
  151. package/src/llama.cpp/ggml/include/ggml-blas.h +23 -0
  152. package/src/llama.cpp/ggml/include/ggml-cann.h +125 -0
  153. package/src/llama.cpp/{ggml-cuda.h → ggml/include/ggml-cuda.h} +3 -0
  154. package/src/llama.cpp/{ggml-metal.h → ggml/include/ggml-metal.h} +1 -2
  155. package/src/llama.cpp/{ggml-sycl.h → ggml/include/ggml-sycl.h} +3 -10
  156. package/src/llama.cpp/{ggml.h → ggml/include/ggml.h} +80 -85
  157. package/src/llama.cpp/ggml/src/CMakeLists.txt +1329 -0
  158. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2193 -0
  159. package/src/llama.cpp/ggml/src/ggml-aarch64.h +39 -0
  160. package/src/llama.cpp/{ggml-alloc.c → ggml/src/ggml-alloc.c} +100 -49
  161. package/src/llama.cpp/{ggml-backend-impl.h → ggml/src/ggml-backend-impl.h} +20 -8
  162. package/src/llama.cpp/{ggml-backend.c → ggml/src/ggml-backend.c} +307 -167
  163. package/src/llama.cpp/ggml/src/ggml-blas.cpp +367 -0
  164. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +198 -0
  165. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +230 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +2944 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/common.h +282 -0
  169. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +32 -0
  170. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +17 -0
  171. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +223 -0
  172. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +186 -0
  173. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +180 -0
  174. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +193 -0
  175. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  176. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +208 -0
  177. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +206 -0
  178. package/src/llama.cpp/ggml/src/ggml-cann.cpp +2023 -0
  179. package/src/llama.cpp/{ggml-common.h → ggml/src/ggml-common.h} +41 -7
  180. package/src/llama.cpp/{ggml-impl.h → ggml/src/ggml-impl.h} +113 -9
  181. package/src/llama.cpp/{ggml-kompute.cpp → ggml/src/ggml-kompute.cpp} +33 -18
  182. package/src/llama.cpp/{ggml-quants.c → ggml/src/ggml-quants.c} +1460 -940
  183. package/src/llama.cpp/{ggml-quants.h → ggml/src/ggml-quants.h} +19 -20
  184. package/src/llama.cpp/{ggml-rpc.cpp → ggml/src/ggml-rpc.cpp} +95 -72
  185. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +27 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +53 -0
  187. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +355 -0
  188. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +195 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +21 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +547 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +27 -0
  192. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +698 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  195. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +3011 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  197. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.hpp +33 -0
  198. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1027 -0
  199. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  200. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +374 -0
  201. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +35 -0
  202. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +66 -0
  203. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +275 -0
  204. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +22 -0
  205. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +251 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +24 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +1140 -0
  208. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +5314 -0
  209. package/src/llama.cpp/{ggml-vulkan.cpp → ggml/src/ggml-vulkan.cpp} +1781 -1868
  210. package/src/llama.cpp/{ggml.c → ggml/src/ggml.c} +1245 -2087
  211. package/src/llama.cpp/{sgemm.cpp → ggml/src/llamafile/sgemm.cpp} +21 -24
  212. package/src/llama.cpp/{sgemm.h → ggml/src/llamafile/sgemm.h} +1 -1
  213. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +5 -0
  214. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +552 -0
  215. package/src/llama.cpp/{llama.h → include/llama.h} +175 -100
  216. package/src/llama.cpp/models/.editorconfig +1 -0
  217. package/src/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  218. package/src/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  219. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  220. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +112 -0
  221. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +46 -0
  222. package/src/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  223. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.inp +112 -0
  224. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.out +46 -0
  225. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  226. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +112 -0
  227. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +46 -0
  228. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  229. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +112 -0
  230. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +46 -0
  231. package/src/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  232. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.inp +112 -0
  233. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.out +46 -0
  234. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  235. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +112 -0
  236. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +46 -0
  237. package/src/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  238. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  239. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  242. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +112 -0
  243. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +46 -0
  244. package/src/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  245. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.inp +112 -0
  246. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.out +46 -0
  247. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  248. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +112 -0
  249. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.out +46 -0
  250. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  251. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +112 -0
  252. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.out +46 -0
  253. package/src/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  254. package/src/llama.cpp/models/ggml-vocab-refact.gguf.inp +112 -0
  255. package/src/llama.cpp/models/ggml-vocab-refact.gguf.out +46 -0
  256. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  257. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +112 -0
  258. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.out +46 -0
  259. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  260. package/src/llama.cpp/requirements/requirements-all.txt +12 -0
  261. package/src/llama.cpp/requirements/requirements-compare-llama-bench.txt +2 -0
  262. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -0
  263. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +3 -0
  264. package/src/llama.cpp/requirements/{requirements-convert.txt → requirements-convert_legacy_llama.txt} +1 -1
  265. package/src/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt +1 -0
  266. package/src/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  267. package/src/llama.cpp/requirements/requirements-pydantic.txt +3 -0
  268. package/src/llama.cpp/requirements/requirements-test-tokenizer-random.txt +1 -0
  269. package/src/llama.cpp/requirements.txt +5 -4
  270. package/src/llama.cpp/scripts/build-info.sh +30 -0
  271. package/src/llama.cpp/scripts/install-oneapi.bat +19 -0
  272. package/src/llama.cpp/src/CMakeLists.txt +33 -0
  273. package/src/llama.cpp/src/llama-grammar.cpp +539 -0
  274. package/src/llama.cpp/src/llama-grammar.h +39 -0
  275. package/src/llama.cpp/src/llama-impl.h +26 -0
  276. package/src/llama.cpp/src/llama-sampling.cpp +635 -0
  277. package/src/llama.cpp/src/llama-sampling.h +56 -0
  278. package/src/llama.cpp/src/llama-vocab.cpp +1721 -0
  279. package/src/llama.cpp/src/llama-vocab.h +130 -0
  280. package/src/llama.cpp/{llama.cpp → src/llama.cpp} +5979 -5260
  281. package/src/llama.cpp/{unicode-data.cpp → src/unicode-data.cpp} +851 -802
  282. package/src/llama.cpp/{unicode.cpp → src/unicode.cpp} +52 -30
  283. package/src/llama.cpp/{unicode.h → src/unicode.h} +5 -1
  284. package/src/llama.cpp/tests/CMakeLists.txt +19 -20
  285. package/src/llama.cpp/tests/test-backend-ops.cpp +245 -67
  286. package/src/llama.cpp/tests/test-chat-template.cpp +57 -3
  287. package/src/llama.cpp/tests/test-double-float.cpp +2 -2
  288. package/src/llama.cpp/tests/test-grad0.cpp +2 -2
  289. package/src/llama.cpp/tests/test-grammar-integration.cpp +978 -31
  290. package/src/llama.cpp/tests/test-grammar-parser.cpp +423 -158
  291. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +508 -135
  292. package/src/llama.cpp/tests/test-llama-grammar.cpp +15 -9
  293. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -1
  294. package/src/llama.cpp/tests/test-quantize-perf.cpp +1 -1
  295. package/src/llama.cpp/tests/test-rope.cpp +3 -4
  296. package/src/llama.cpp/tests/test-sampling.cpp +5 -5
  297. package/src/llama.cpp/tests/test-tokenizer-0.cpp +6 -6
  298. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +20 -15
  299. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +22 -11
  300. package/bin/darwin/arm64/default.metallib +0 -0
  301. package/bin/darwin/x64/default.metallib +0 -0
  302. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +0 -5
  303. package/src/llama.cpp/examples/beam-search/beam-search.cpp +0 -188
  304. package/src/llama.cpp/examples/finetune/finetune.cpp +0 -1862
  305. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +0 -55
  306. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +0 -5
  307. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +0 -1253
  308. package/src/llama.cpp/ggml-opencl.cpp +0 -2305
  309. package/src/llama.cpp/ggml-opencl.h +0 -36
  310. package/src/llama.cpp/ggml-sycl.cpp +0 -17340
  311. package/src/llama.cpp/ggml-vulkan-shaders.hpp +0 -81211
  312. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf-update.txt +0 -2
  313. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +0 -2
  314. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +0 -1
  315. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +0 -24
  316. /package/src/llama.cpp/{ggml-alloc.h → ggml/include/ggml-alloc.h} +0 -0
  317. /package/src/llama.cpp/{ggml-kompute.h → ggml/include/ggml-kompute.h} +0 -0
  318. /package/src/llama.cpp/{ggml-rpc.h → ggml/include/ggml-rpc.h} +0 -0
  319. /package/src/llama.cpp/{ggml-vulkan.h → ggml/include/ggml-vulkan.h} +0 -0
  320. /package/src/llama.cpp/{unicode-data.h → src/unicode-data.h} +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp
@@ -0,0 +1,355 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #ifndef GGML_SYCL_COMMON_HPP
+ #define GGML_SYCL_COMMON_HPP
+
+ #include <fstream>
+ #include <iostream>
+
+ #include "dpct/helper.hpp"
+ #include "ggml-sycl.h"
+ #include "presets.hpp"
+
+ #define GGML_COMMON_DECL_SYCL
+ #define GGML_COMMON_IMPL_SYCL
+ #include "ggml-common.h"
+
+ void* ggml_sycl_host_malloc(size_t size);
+ void ggml_sycl_host_free(void* ptr);
+
+ static int g_ggml_sycl_debug = 0;
+ #define GGML_SYCL_DEBUG(...) \
+ do { \
+ if (g_ggml_sycl_debug) \
+ fprintf(stderr, __VA_ARGS__); \
+ } while (0)
+
+ #define CHECK_TRY_ERROR(expr) \
+ [&]() { \
+ try { \
+ expr; \
+ return dpct::success; \
+ } catch (std::exception const& e) { \
+ std::cerr << e.what() << "\nException caught at file:" << __FILE__ \
+ << ", line:" << __LINE__ << ", func:" << __func__ \
+ << std::endl; \
+ return dpct::default_error; \
+ } \
+ }()
+
+
+ #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
+ #define VER_4VEC 610 // todo for hardward optimize.
+ #define VER_GEN9 700 // todo for hardward optimize.
+ #define VER_GEN12 1000000 // todo for hardward optimize.
+ #define VER_GEN13 (VER_GEN12 + 1030) // todo for hardward optimize.
+
+ #define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares
+
+ // define for XMX in Intel GPU
+ // TODO: currently, it's not used for XMX really.
+ #if !defined(GGML_SYCL_FORCE_MMQ)
+ #define SYCL_USE_XMX
+ #endif
+
+ // max batch size to use MMQ kernels when tensor cores are available
+ #define MMQ_MAX_BATCH_SIZE 32
+
+ #if defined(_MSC_VER)
+ #pragma warning(disable : 4244 4267) // possible loss of data
+ #endif
+
+ // dmmv = dequantize_mul_mat_vec
+ #ifndef GGML_SYCL_DMMV_X
+ #define GGML_SYCL_DMMV_X 32
+ #endif
+ #ifndef GGML_SYCL_MMV_Y
+ #define GGML_SYCL_MMV_Y 1
+ #endif
+
+ typedef sycl::queue *queue_ptr;
+
+ enum ggml_sycl_backend_gpu_mode {
+ SYCL_UNSET_GPU_MODE = -1,
+ SYCL_SINGLE_GPU_MODE = 0,
+ SYCL_MUL_GPU_MODE
+ };
+
+ static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+ static void crash() {
+ int* ptr = NULL;
+ *ptr = 0;
+ }
+
+ [[noreturn]] static void ggml_sycl_error(
+ const char* stmt,
+ const char* func,
+ const char* file,
+ const int line,
+ const char* msg) {
+ fprintf(stderr, "SYCL error: %s: %s\n", stmt, msg);
+ fprintf(stderr, " in function %s at %s:%d\n", func, file, line);
+ GGML_ABORT("SYCL error");
+ }
+
+ #define SYCL_CHECK(err) \
+ do { \
+ auto err_ = (err); \
+ if (err_ != 0) \
+ ggml_sycl_error( \
+ #err, \
+ __func__, \
+ __FILE__, \
+ __LINE__, \
+ "Meet error in this line code!"); \
+ } while (0)
+
+ #if DPCT_COMPAT_RT_VERSION >= 11100
+ #define GGML_SYCL_ASSUME(x) __builtin_assume(x)
+ #else
+ #define GGML_SYCL_ASSUME(x)
+ #endif // DPCT_COMPAT_RT_VERSION >= 11100
+
+ #ifdef GGML_SYCL_F16
+ typedef sycl::half dfloat; // dequantize float
+ typedef sycl::half2 dfloat2;
+ #else
+ typedef float dfloat; // dequantize float
+ typedef sycl::float2 dfloat2;
+ #endif // GGML_SYCL_F16
+
+ #define MMVQ_MAX_BATCH_SIZE 8
+
+ static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+ static int g_all_sycl_device_count = -1;
+ static bool g_ggml_backend_sycl_buffer_type_initialized = false;
+
+ static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode =
+ SYCL_UNSET_GPU_MODE;
+
+ static void* g_scratch_buffer = nullptr;
+ static size_t g_scratch_size = 0; // disabled by default
+ static size_t g_scratch_offset = 0;
+
+ [[noreturn]] static inline void bad_arch(const sycl::stream& stream_ct1) {
+ stream_ct1 << "ERROR: ggml-sycl was compiled without support for the "
+ "current GPU architecture.\n";
+ // __trap();
+ std::exit(1);
+
+ (void)bad_arch; // suppress unused function warning
+ }
+
+ int get_current_device_id();
+
+ inline dpct::err0 ggml_sycl_set_device(const int device) try {
+
+ int current_device_id;
+ SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
+
+ // GGML_SYCL_DEBUG("ggml_sycl_set_device device_id=%d,
+ // current_device_id=%d\n", device, current_device);
+ if (device == current_device_id) {
+ return 0;
+ }
+
+ return CHECK_TRY_ERROR(dpct::select_device(device));
+ } catch (sycl::exception const& exc) {
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ crash();
+ std::exit(1);
+ }
+
+ //////////////////////
+
+ struct ggml_sycl_device_info {
+ int device_count;
+
+ struct sycl_device_info {
+ int cc; // compute capability
+ // int nsm; // number of streaming multiprocessors
+ // size_t smpb; // max. shared memory per block
+ bool vmm; // virtual memory support
+ size_t total_vram;
+ };
+
+ sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
+
+ std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
+
+ int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0};
+ };
+
+ const ggml_sycl_device_info & ggml_sycl_info();
+
+ struct ggml_sycl_pool {
+ virtual ~ggml_sycl_pool() = default;
+
+ virtual void * alloc(size_t size, size_t * actual_size) = 0;
+ virtual void free(void * ptr, size_t size) = 0;
+ };
+
+ template<typename T>
+ struct ggml_sycl_pool_alloc {
+ ggml_sycl_pool * pool = nullptr;
+ T * ptr = nullptr;
+ size_t actual_size = 0;
+
+ explicit ggml_sycl_pool_alloc(ggml_sycl_pool & pool) : pool(&pool) {
+ }
+
+ ggml_sycl_pool_alloc(ggml_sycl_pool & pool, size_t size) : pool(&pool) {
+ alloc(size);
+ }
+
+ ~ggml_sycl_pool_alloc() {
+ if (ptr != nullptr) {
+ pool->free(ptr, actual_size);
+ }
+ }
+
+ // size is in number of elements
+ T * alloc(size_t size) {
+ GGML_ASSERT(pool != nullptr);
+ GGML_ASSERT(ptr == nullptr);
+ ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+ return ptr;
+ }
+
+ T * alloc(ggml_sycl_pool & pool, size_t size) {
+ this->pool = &pool;
+ return alloc(size);
+ }
+
+ T * get() {
+ return ptr;
+ }
+
+ ggml_sycl_pool_alloc() = default;
+ ggml_sycl_pool_alloc(const ggml_sycl_pool_alloc &) = delete;
+ ggml_sycl_pool_alloc(ggml_sycl_pool_alloc &&) = delete;
+ ggml_sycl_pool_alloc& operator=(const ggml_sycl_pool_alloc &) = delete;
+ ggml_sycl_pool_alloc& operator=(ggml_sycl_pool_alloc &&) = delete;
+ };
+
+ // backend interface
+
+ struct ggml_tensor_extra_gpu {
+ void* data_device[GGML_SYCL_MAX_DEVICES]; // 1 pointer for each device for split
+ // tensors
+ dpct::event_ptr events[GGML_SYCL_MAX_DEVICES]
+ [GGML_SYCL_MAX_STREAMS]; // events for synchronizing multiple GPUs
+ };
+
+ struct ggml_backend_sycl_context {
+ int device;
+ std::string name;
+
+ queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };
+
+ explicit ggml_backend_sycl_context(int device) :
+ device(device),
+ name(GGML_SYCL_NAME + std::to_string(device)) {
+ }
+
+ queue_ptr stream(int device, int stream) {
+ if (qptrs[device][stream] == nullptr) {
+ qptrs[device][stream] = &(dpct::get_device(device).default_queue());
+ }
+ return qptrs[device][stream];
+ }
+
+ queue_ptr stream() {
+ return stream(device, 0);
+ }
+
+ // pool
+ std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+
+ static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
+
+ ggml_sycl_pool & pool(int device) {
+ if (pools[device] == nullptr) {
+ pools[device] = new_pool_for_device(stream(device,0), device);
+ }
+ return *pools[device];
+ }
+
+ ggml_sycl_pool & pool() {
+ return pool(device);
+ }
+ };
+
+ // common device functions
+
+ static __dpct_inline__ float warp_reduce_sum(float x,
+ const sycl::nd_item<3>& item_ct1) {
+ #pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ /*
+ DPCT1096:98: The right-most dimension of the work-group used in the SYCL
+ kernel that calls this function may be less than "32". The function
+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the
+ CPU device. Modify the size of the work-group to ensure that the value
+ of the right-most dimension is a multiple of "32".
+ */
+ x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
+ }
+ return x;
+ }
+
+ static __dpct_inline__ sycl::float2
+ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
+ #pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(),
+ mask);
+ a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(),
+ mask);
+ }
+ return a;
+ }
+
+ static __dpct_inline__ float warp_reduce_max(float x,
+ const sycl::nd_item<3>& item_ct1) {
+ #pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+ /*
+ DPCT1096:97: The right-most dimension of the work-group used in the SYCL
+ kernel that calls this function may be less than "32". The function
+ "dpct::permute_sub_group_by_xor" may return an unexpected result on the
+ CPU device. Modify the size of the work-group to ensure that the value
+ of the right-most dimension is a multiple of "32".
+ */
+ x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
+ item_ct1.get_sub_group(), x, mask));
+ }
+ return x;
+ }
+
+ // Helper for vec loading aligned data
+ template <typename Tp, int n>
+ inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
+ return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
+ }
+
+ // Helper for accessing pointers with no warnings
+ template <typename Tp, int dim>
+ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
+ return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
+ }
+
+ #endif // GGML_SYCL_COMMON_HPP
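
The error handling added in common.hpp nests two macros: CHECK_TRY_ERROR wraps an expression in an immediately-invoked lambda that converts a thrown std::exception into a dpct error code, and SYCL_CHECK aborts when that code is non-zero (see ggml_sycl_set_device above for the real usage). Below is a minimal host-only sketch of the same composition; TRY_EXPR, CHECK, FAKE_SUCCESS, FAKE_FAILURE, and might_throw are stand-ins invented for this illustration, not names from the package.

#include <cstdio>
#include <cstdlib>
#include <exception>
#include <stdexcept>

// Hypothetical stand-ins for dpct::err0 / dpct::success / dpct::default_error.
using fake_error_t = int;
constexpr fake_error_t FAKE_SUCCESS = 0;
constexpr fake_error_t FAKE_FAILURE = 1;

// Mirrors CHECK_TRY_ERROR: evaluate an expression inside an immediately-invoked
// lambda and translate any std::exception into an error code.
#define TRY_EXPR(expr)                                        \
    [&]() -> fake_error_t {                                   \
        try {                                                 \
            expr;                                             \
            return FAKE_SUCCESS;                              \
        } catch (std::exception const & e) {                  \
            std::fprintf(stderr, "caught: %s\n", e.what());   \
            return FAKE_FAILURE;                              \
        }                                                     \
    }()

// Mirrors SYCL_CHECK: abort when the wrapped expression reported an error.
#define CHECK(err)                                                 \
    do {                                                           \
        if ((err) != FAKE_SUCCESS) {                               \
            std::fprintf(stderr, "check failed: %s\n", #err);      \
            std::abort();                                          \
        }                                                          \
    } while (0)

static int might_throw(bool fail) {
    if (fail) throw std::runtime_error("device selection failed");
    return 42;
}

int main() {
    int value = 0;
    CHECK(TRY_EXPR(value = might_throw(false)));   // succeeds, value == 42
    std::printf("value = %d\n", value);
    // CHECK(TRY_EXPR(value = might_throw(true))); // would log and abort
    return 0;
}

The immediately-invoked lambda is what lets CHECK_TRY_ERROR behave as an expression, so it can be passed straight into the outer check.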
package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp
@@ -0,0 +1,195 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #include "concat.hpp"
+ #include "common.hpp"
+
+ static void concat_f32_dim0(const float *x, const float *y, float *dst,
+ const int ne0, const int ne00,
+ const sycl::nd_item<3> &item_ct1) {
+ int nidx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ if (nidx < ne00) { // src0
+ int offset_src = nidx + item_ct1.get_group(1) * ne00 +
+ item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1);
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) +
+ item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1);
+ dst[offset_dst] = y[offset_src];
+ }
+ }
+
+ static void concat_f32_dim1(const float *x, const float *y, float *dst,
+ const int ne0, const int ne01,
+ const sycl::nd_item<3> &item_ct1) {
+ int nidx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ if (item_ct1.get_group(1) < ne01) { // src0
+ int offset_src =
+ nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx + (item_ct1.get_group(1) - ne01) * ne0 +
+ item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01);
+ dst[offset_dst] = y[offset_src];
+ }
+ }
+
+ static void concat_f32_dim2(const float *x, const float *y, float *dst,
+ const int ne0, const int ne02,
+ const sycl::nd_item<3> &item_ct1) {
+ int nidx = item_ct1.get_local_id(2) +
+ item_ct1.get_group(2) * item_ct1.get_local_range(2);
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ if (item_ct1.get_group(0) < ne02) { // src0
+ int offset_src = nidx + item_ct1.get_group(1) * ne0 +
+ item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx + item_ct1.get_group(1) * ne0 +
+ (item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
+ dst[offset_dst] = y[offset_src];
+ }
+ }
+
+ static void concat_f32_sycl(const float *x, const float *y, float *dst,
+ int ne00, int ne01, int ne02, int ne0, int ne1,
+ int ne2, int dim, queue_ptr stream) {
+ int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
+ sycl::range<3> gridDim(ne2, ne1, num_blocks);
+ switch (dim) {
+ case 0:
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim *
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
+ });
+ break;
+ case 1:
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim *
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
+ });
+ break;
+ default:
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim *
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
+ });
+ break;
+ }
+ }
+
+ // non-contiguous kernel (slow)
+ static void concat_f32_sycl_non_cont(
+ queue_ptr stream, const char *src0, const char *src1, char *dst,
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
+ uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
+ int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10,
+ uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1,
+ int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
+ uint64_t nb3, int32_t dim) {
+ sycl::range<3> gridDim(ne3, ne2, ne1);
+ stream->parallel_for(
+ sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
+ [=](sycl::nd_item<3> item_ct1) {
+ int64_t i3 = item_ct1.get_group(0);
+ int64_t i2 = item_ct1.get_group(1);
+ int64_t i1 = item_ct1.get_group(2);
+
+ int64_t o[4] = {0, 0, 0, 0};
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+ const float *x;
+
+ for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
+ i0 += item_ct1.get_local_range(2)) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
+ (i0)*nb00);
+ } else {
+ x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
+ (i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
+ }
+
+ float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
+
+ *y = *x;
+ }
+ });
+ }
+
+ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst) {
+ queue_ptr stream = ctx.stream();
+
+ const int32_t dim = ((int32_t *)dst->op_params)[0];
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+ const float *src0_d = (const float *)src0->data;
+ const float *src1_d = (const float *)src1->data;
+
+ float *dst_d = (float *)dst->data;
+
+ if (dim != 3) {
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_sycl(
+ src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4),
+ dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1],
+ src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+ }
+ } else {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+
+ SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
+ SYCL_CHECK(CHECK_TRY_ERROR(
+ stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
+ }
+ } else
+ concat_f32_sycl_non_cont(
+ stream, (const char *)src0->data, (const char *)src1->data,
+ (char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
+ src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1],
+ src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
+ dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
+ }
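
For the contiguous case with dim != 3, ggml_sycl_op_concat above calls concat_f32_sycl once per outermost index i3, and the per-dimension kernels place all of src0's elements before src1's along the chosen dimension. Below is a plain CPU sketch of that ordering for a 2-D concat along dim 1; concat_dim1_ref and the toy shapes are assumptions made for the illustration, not code from the diff.

#include <cstdio>
#include <vector>

// CPU reference for concatenating two contiguous float tensors along dim 1,
// matching the element order concat_f32_dim1 produces: ne0 is the innermost
// (contiguous) dimension, rows 0..ne01-1 come from src0 and the rest from src1.
static void concat_dim1_ref(const float *src0, const float *src1, float *dst,
                            int ne0, int ne01, int ne11) {
    const int ne1 = ne01 + ne11;                 // rows in the result
    for (int i1 = 0; i1 < ne1; ++i1) {
        for (int i0 = 0; i0 < ne0; ++i0) {
            dst[i1 * ne0 + i0] = (i1 < ne01)
                ? src0[i1 * ne0 + i0]            // first ne01 rows from src0
                : src1[(i1 - ne01) * ne0 + i0];  // remaining rows from src1
        }
    }
}

int main() {
    const int ne0 = 3;
    std::vector<float> a = {1, 2, 3, 4, 5, 6};   // shape ne0=3, ne01=2
    std::vector<float> b = {7, 8, 9};            // shape ne0=3, ne11=1
    std::vector<float> out(a.size() + b.size());
    concat_dim1_ref(a.data(), b.data(), out.data(), ne0, 2, 1);
    for (float v : out) std::printf("%g ", v);   // prints: 1 2 3 4 5 6 7 8 9
    std::printf("\n");
    return 0;
}

Here ne0 is the innermost, contiguous dimension, which matches how the kernels index dst with nidx as the fastest-varying coordinate.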
package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp
@@ -0,0 +1,21 @@
+ //
+ // MIT license
+ // Copyright (C) 2024 Intel Corporation
+ // SPDX-License-Identifier: MIT
+ //
+
+ //
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ // See https://llvm.org/LICENSE.txt for license information.
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ //
+
+ #ifndef GGML_SYCL_CONCAT_HPP
+ #define GGML_SYCL_CONCAT_HPP
+
+ #include "common.hpp"
+
+ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst);
+
+ #endif // GGML_SYCL_CONCAT_HPP