@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt
@@ -0,0 +1,84 @@
+ if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
+     message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
+ endif()
+
+ check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
+
+ if (DEFINED ENV{ONEAPI_ROOT})
+     message(STATUS "Using oneAPI Release SYCL compiler (icpx).")
+ elseif(SUPPORTS_SYCL)
+     message(WARNING "Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}.
+         If you expected the oneAPI Release compiler, please install oneAPI & source it, like:
+         source /opt/intel/oneapi/setvars.sh")
+ else()
+     message(FATAL_ERROR, "C++ compiler lacks SYCL support.")
+ endif()
+ message(STATUS "SYCL found")
+ #todo: AOT
+
+ ggml_add_backend_library(ggml-sycl
+     ggml-sycl.cpp
+     ../../include/ggml-sycl.h
+ )
+
+ if (GGML_SYCL_F16)
+     if (GGML_SYCL_TARGET STREQUAL "AMD")
+         message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
+     endif()
+     add_compile_definitions(GGML_SYCL_F16)
+ endif()
+
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
+
+ if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+     add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+     # INFO: Allowed Sub_group_sizes are not consistent through all
+     # hip targets. For example, 64 is used for certain models, but the backend
+     # does not support it.
+     # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
+     add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+ else()
+     add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
+ endif()
+
+ file(GLOB GGML_HEADERS_SYCL "*.hpp")
+ file(GLOB GGML_SOURCES_SYCL "*.cpp")
+ target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
+
+ find_package(DNNL)
+ message("-- DNNL found:" ${DNNL_FOUND})
+
+ if (GGML_SYCL_TARGET STREQUAL "INTEL")
+     add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+ else()
+     add_compile_definitions(GGML_SYCL_DNNL=0)
+ endif()
+
+ if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+     target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+ endif()
+
+ if (WIN32)
+     find_package(IntelSYCL REQUIRED)
+     find_package(MKL REQUIRED)
+     target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+ else()
+     if (GGML_SYCL_TARGET STREQUAL "INTEL")
+         target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+     elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
+         add_compile_definitions(GGML_SYCL_NVIDIA)
+         target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas)
+     elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+         if (NOT GGML_SYCL_DEVICE_ARCH)
+             message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
+         endif()
+         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa")
+         target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+     endif()
+
+     if (GGML_SYCL_DEVICE_ARCH)
+         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}")
+     endif()
+ endif()
package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp
@@ -26,5 +26,8 @@
  #include "softmax.hpp"
  #include "tsembd.hpp"
  #include "im2col.hpp"
+ #include "wkv6.hpp"
+ #include "outprod.hpp"
+ #include "element_wise.hpp"
 
  #endif // GGML_SYCL_BACKEND_HPP
package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp
@@ -11,6 +11,7 @@
  //
 
  #include "common.hpp"
+ #include "ggml-impl.h"
 
  int get_current_device_id() {
      return dpct::dev_mgr::instance().current_device_id();
@@ -28,11 +29,7 @@ void* ggml_sycl_host_malloc(size_t size) try {
 
      if (err != 0) {
          // clear the error
-         fprintf(
-             stderr,
-             "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
-             size / 1024.0 / 1024.0,
-             "syclGetErrorString is not supported");
+         GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported");
          return nullptr;
      }
 
@@ -62,3 +59,37 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
      }
      return sycl_down_blk_size;
  }
+
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                           const ggml_tensor *src1, ggml_tensor *dst,
+                           const ggml_sycl_op_flatten_t op) try {
+
+     const bool use_src1 = src1 != nullptr;
+
+     GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+     GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+     // dd = data device
+     float * src0_ddf = (float *) src0->data;
+     float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+     float * dst_ddf = (float *) dst->data;
+
+     ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+     ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+     ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
+
+     ggml_sycl_set_device(ctx.device);
+     queue_ptr main_stream = ctx.stream();
+     // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+     //                 ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+     // do the computation
+     op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+     // print_ggml_tensor("tensor", dst);
+ }
+ catch (sycl::exception const &exc) {
+
+     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+               << ", line:" << __LINE__ << std::endl;
+     std::exit(1);
+ }
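
The new ggml_sycl_op_flatten helper factors the per-operator boilerplate (backend asserts, raw float pointer extraction, device and queue selection, SYCL exception handling) into one place; individual operators then only implement the ggml_sycl_op_flatten_t callback declared in common.hpp. As a rough illustration only (the op name and kernel body below are invented, not part of this diff; the surrounding ggml-sycl headers are assumed), a minimal elementwise operator written against that callback shape could look like:

    // Hypothetical example: a ReLU written against ggml_sycl_op_flatten_t.
    // By the time this runs, ggml_sycl_op_flatten has already resolved the
    // tensors to raw float pointers and chosen the context's queue.
    static void ggml_sycl_op_my_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
                                     const ggml_tensor * src1, ggml_tensor * dst,
                                     const float * src0_dd, const float * src1_dd, float * dst_dd,
                                     const queue_ptr & main_stream) {
        const int64_t k = ggml_nelements(dst);
        // one work-item per element of the flattened tensor
        main_stream->parallel_for(sycl::range<1>(k), [=](sycl::id<1> i) {
            dst_dd[i] = src0_dd[i] > 0.0f ? src0_dd[i] : 0.0f;
        });
        GGML_UNUSED(ctx); GGML_UNUSED(src0); GGML_UNUSED(src1); GGML_UNUSED(src1_dd);
    }

    // dispatch site would then be:
    // ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_my_relu);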
package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp
@@ -404,4 +404,263 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
 
  int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
 
+ typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                        const ggml_tensor *src1,
+                                        ggml_tensor *dst, const float *src0_dd,
+                                        const float *src1_dd, float *dst_dd,
+                                        const queue_ptr &main_stream);
+
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+                         int ne0, int ne1, int ne2, int ne3,
+                         int ne10, int ne11, int ne12, int ne13,
+                         /*int s0, */ int s1, int s2, int s3,
+                         /*int s10,*/ int s11, int s12, int s13,
+                         const sycl::nd_item<3> &item_ct1) {
+     const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                     item_ct1.get_local_id(2);
+     const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                     item_ct1.get_local_id(1));
+     const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne3;
+     const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne3;
+
+     if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+         return;
+     }
+
+     const int i11 = i1 % ne11;
+     const int i12 = i2 % ne12;
+     const int i13 = i3 % ne13;
+
+     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+     const size_t i_dst = i_src0;
+
+     const src0_t * src0_row = src0 + i_src0;
+     const src1_t * src1_row = src1 + i_src1;
+     dst_t * dst_row = dst + i_dst;
+
+     for (int i0 = i0s; i0 < ne0;
+          i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+         const int i10 = i0 % ne10;
+         dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+     }
+ }
+
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+                                 int ne0, int ne1, int ne2, int ne3,
+                                 int ne10, int ne11, int ne12, int ne13,
+                                 /*int s0, */ int s1, int s2, int s3,
+                                 /*int s10,*/ int s11, int s12, int s13,
+                                 const sycl::nd_item<3> &item_ct1) {
+
+     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                   item_ct1.get_local_id(2);
+
+     const int i3 = i/(ne2*ne1*ne0);
+     const int i2 = (i/(ne1*ne0)) % ne2;
+     const int i1 = (i/ne0) % ne1;
+     const int i0 = i % ne0;
+
+     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+         return;
+     }
+
+     const int i11 = i1 % ne11;
+     const int i12 = i2 % ne12;
+     const int i13 = i3 % ne13;
+
+     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+     const size_t i_dst = i_src0;
+
+     const src0_t * src0_row = src0 + i_src0;
+     const src1_t * src1_row = src1 + i_src1;
+     dst_t * dst_row = dst + i_dst;
+
+     const int i10 = i0 % ne10;
+     dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+ }
+
+
+ template<float (*bin_op)(const float, const float)>
+ struct bin_bcast_sycl {
+     template <typename src0_t, typename src1_t, typename dst_t>
+     void operator()(ggml_backend_sycl_context & ctx,
+                     const struct ggml_tensor *src0,
+                     const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                     const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                     queue_ptr stream) {
+
+         GGML_TENSOR_BINARY_OP_LOCALS
+
+         int nr0 = ne10/ne0;
+         int nr1 = ne11/ne1;
+         int nr2 = ne12/ne2;
+         int nr3 = ne13/ne3;
+
+         int nr[4] = { nr0, nr1, nr2, nr3 };
+
+         // collapse dimensions until first broadcast dimension
+         int64_t cne0[] = {ne0, ne1, ne2, ne3};
+         int64_t cne1[] = {ne10, ne11, ne12, ne13};
+         size_t cnb0[] = {nb0, nb1, nb2, nb3};
+         size_t cnb1[] = {nb10, nb11, nb12, nb13};
+         auto collapse = [](int64_t cne[]) {
+             cne[0] *= cne[1];
+             cne[1] = cne[2];
+             cne[2] = cne[3];
+             cne[3] = 1;
+         };
+
+         auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+             cnb[1] *= cne[1];
+             cnb[2] *= cne[2];
+             cnb[3] *= cne[3];
+         };
+
+         for (int i = 0; i < 4; i++) {
+             if (nr[i] != 1) {
+                 break;
+             }
+             if (i > 0) {
+                 collapse_nb(cnb0, cne0);
+                 collapse_nb(cnb1, cne1);
+                 collapse(cne0);
+                 collapse(cne1);
+             }
+         }
+         {
+             int64_t ne0 = cne0[0];
+             int64_t ne1 = cne0[1];
+             int64_t ne2 = cne0[2];
+             int64_t ne3 = cne0[3];
+
+             int64_t ne10 = cne1[0];
+             int64_t ne11 = cne1[1];
+             int64_t ne12 = cne1[2];
+             int64_t ne13 = cne1[3];
+
+             size_t nb0 = cnb0[0];
+             size_t nb1 = cnb0[1];
+             size_t nb2 = cnb0[2];
+             size_t nb3 = cnb0[3];
+
+             size_t nb10 = cnb1[0];
+             size_t nb11 = cnb1[1];
+             size_t nb12 = cnb1[2];
+             size_t nb13 = cnb1[3];
+
+             size_t s0 = nb0 / sizeof(dst_t);
+             size_t s1 = nb1 / sizeof(dst_t);
+             size_t s2 = nb2 / sizeof(dst_t);
+             size_t s3 = nb3 / sizeof(dst_t);
+
+             size_t s10 = nb10 / sizeof(src1_t);
+             size_t s11 = nb11 / sizeof(src1_t);
+             size_t s12 = nb12 / sizeof(src1_t);
+             size_t s13 = nb13 / sizeof(src1_t);
+
+             GGML_ASSERT(s0 == 1);
+             GGML_ASSERT(s10 == 1);
+
+             const int block_size = 128;
+
+             int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+             sycl::range<3> block_dims(1, 1, 1);
+             block_dims[2] = std::min<unsigned int>(hne0, block_size);
+             block_dims[1] = std::min<unsigned int>(
+                 ne1, block_size / (unsigned int)block_dims[2]);
+             block_dims[0] = std::min(
+                 std::min<unsigned int>(
+                     ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                    (unsigned int)block_dims[1]),
+                 64U);
+
+             sycl::range<3> block_nums(
+                 (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                 (ne1 + block_dims[1] - 1) / block_dims[1],
+                 (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+             if (block_nums[0] > 65535) {
+                 // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                 {
+                     dpct::has_capability_or_fail(stream->get_device(),
+                                                  {sycl::aspect::fp16});
+
+                     stream->parallel_for(
+                         sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                               sycl::range<3>(1, 1, block_size),
+                                           sycl::range<3>(1, 1, block_size)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             k_bin_bcast_unravel<bin_op>(
+                                 src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                 ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                 s13, item_ct1);
+                         });
+                 }
+             } else {
+                 /*
+                 DPCT1049:16: The work-group size passed to the SYCL kernel may
+                 exceed the limit. To get the device limit, query
+                 info::device::max_work_group_size. Adjust the work-group size if
+                 needed.
+                 */
+                 dpct::has_capability_or_fail(stream->get_device(),
+                                              {sycl::aspect::fp16});
+
+                 stream->parallel_for(
+                     sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                     [=](sycl::nd_item<3> item_ct1) {
+                         k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                             ne2, ne3, ne10, ne11, ne12, ne13,
+                                             s1, s2, s3, s11, s12, s13,
+                                             item_ct1);
+                     });
+             }
+         }
+         GGML_UNUSED(ctx);
+     }
+ };
+
+ template <class op>
+ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst,
+                                    const float *src0_dd, const float *src1_dd,
+                                    float *dst_dd,
+                                    const queue_ptr &main_stream) {
+
+     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+         op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+         op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+              (sycl::half *)dst_dd, main_stream);
+     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+         op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+              main_stream);
+     } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+         op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+              main_stream);
+     } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+         op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+              main_stream);
+     } else {
+         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                 ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+         GGML_ABORT("fatal error");
+     }
+ }
+
+
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                           const ggml_tensor *src1, ggml_tensor *dst,
+                           const ggml_sycl_op_flatten_t op);
+
  #endif // GGML_SYCL_COMMON_HPP
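
The k_bin_bcast kernels above implement ggml-style broadcasting: the destination index along each dimension is mapped into src1 by taking it modulo src1's extent in that dimension (i11 = i1 % ne11, and so on), so a smaller src1 is logically repeated across dst. A self-contained CPU sketch of that indexing rule follows (the toy shapes and the add operation are assumptions for illustration; this snippet is not part of the package):

    #include <cstdio>

    int main() {
        // dst/src0 is 2x4; src1 is 1x4 and is broadcast along dimension 1
        const int ne0 = 4, ne1 = 2;      // dst extents (fastest-varying first, as in ggml)
        const int ne10 = 4, ne11 = 1;    // src1 extents
        float src0[ne1][ne0] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
        float src1[ne11][ne10] = {{10, 20, 30, 40}};
        float dst[ne1][ne0];

        for (int i1 = 0; i1 < ne1; i1++) {
            const int i11 = i1 % ne11;            // same modulo rule as k_bin_bcast
            for (int i0 = 0; i0 < ne0; i0++) {
                const int i10 = i0 % ne10;
                dst[i1][i0] = src0[i1][i0] + src1[i11][i10];  // bin_op = add
            }
        }

        for (int i1 = 0; i1 < ne1; i1++) {        // prints: 11 22 33 44 / 15 26 37 48
            for (int i0 = 0; i0 < ne0; i0++) {
                printf("%g ", dst[i1][i0]);
            }
            printf("\n");
        }
        return 0;
    }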
package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp
@@ -47,7 +47,7 @@ static void concat_f32_dim1(const float *x, const float *y, float *dst,
      // operation
      int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
                       item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
-     if (item_ct1.get_group(1) < ne01) { // src0
+     if (item_ct1.get_group(1) < (size_t) ne01) { // src0
          int offset_src =
              nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
          dst[offset_dst] = x[offset_src];
@@ -70,7 +70,7 @@ static void concat_f32_dim2(const float *x, const float *y, float *dst,
      // operation
      int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
                       item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
-     if (item_ct1.get_group(0) < ne02) { // src0
+     if (item_ct1.get_group(0) < (size_t) ne02) { // src0
          int offset_src = nidx + item_ct1.get_group(1) * ne0 +
                           item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
          dst[offset_dst] = x[offset_src];
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
              concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
          });
          break;
+     // dim >=2 will be dispatched to the default path
      default:
          stream->parallel_for(
              sycl::nd_range<3>(gridDim *
package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp
@@ -424,7 +424,7 @@ static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y,
      const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
 
      // make each work-item deal with more elements since sycl global range can not exceed max int
-     const src_t * x = (src_t *) vx;
+     const src_t * x = (const src_t *) vx;
      for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
          y[i] = x[i];
      }
package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp
@@ -1015,9 +1015,9 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
          break;
      }
 
-     (void) src1;
-     (void) dst;
-     (void) src1_ddq_i;
-     (void) src1_ncols;
-     (void) src1_padded_row_size;
+     GGML_UNUSED(src1);
+     GGML_UNUSED(dst);
+     GGML_UNUSED(src1_ddq_i);
+     GGML_UNUSED(src1_ncols);
+     GGML_UNUSED(src1_padded_row_size);
  }
package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -15,6 +15,7 @@
 
  #include <sycl/sycl.hpp>
  #include <sycl/half_type.hpp>
+ #include <syclcompat/math.hpp>
  #include <oneapi/mkl.hpp>
  #include <map>
 
@@ -1236,7 +1237,7 @@ namespace dpct
 
      std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
      {
-         auto it = m_map.upper_bound((byte_t *)ptr);
+         auto it = m_map.upper_bound(const_cast<byte_t *>(reinterpret_cast<const byte_t *>(ptr)));
          if (it == m_map.end())
          {
              // Not a virtual pointer.
@@ -1688,9 +1689,14 @@ namespace dpct
          auto data_a = get_memory<const Ta>(a);
          auto data_b = get_memory<const Tb>(b);
          auto data_c = get_memory<Tc>(c);
-         oneapi::mkl::blas::column_major::gemm(
-             q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-             data_b, ldb, beta_value, data_c, ldc);
+ #ifdef GGML_SYCL_NVIDIA
+         oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
+                                               a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+                                               beta_value, data_c, ldc);
+ #else
+         oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
+                                               beta_value, data_c, ldc);
+ #endif
      }
 
      template <typename VecT, class BinaryOperation, class = void>
@@ -1753,14 +1759,22 @@ namespace dpct
          matrix_info->ld_info[2] = ldc;
          matrix_info->groupsize_info = batch_size;
 
+ #ifdef GGML_SYCL_NVIDIA
+         sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
+             oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
+             matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
+             matrix_info->size_info + 2, matrix_info->value_info, reinterpret_cast<const Ta **>(a),
+             matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+             matrix_info->value_info + 1, reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1,
+             &(matrix_info->groupsize_info));
+ #else
          sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-             q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
-             matrix_info->size_info, matrix_info->size_info + 1,
-             matrix_info->size_info + 2, matrix_info->value_info,
-             reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
-             reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
-             matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
+             q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
+             matrix_info->size_info + 1, matrix_info->size_info + 2, matrix_info->value_info,
+             reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
+             matrix_info->ld_info + 1, matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
              matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
+ #endif
 
          q.submit([&](sycl::handler &cgh)
          {
@@ -1782,10 +1796,16 @@ namespace dpct
          auto data_a = get_memory<const Ta>(a);
          auto data_b = get_memory<const Tb>(b);
          auto data_c = get_memory<Tc>(c);
+ #ifdef GGML_SYCL_NVIDIA
          oneapi::mkl::blas::column_major::gemm_batch(
-             q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-             stride_a, data_b, ldb, stride_b, beta_value,
-             data_c, ldc, stride_c, batch_size);
+             oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
+             alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
+             batch_size);
+ #else
+         oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
+                                                     stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
+                                                     stride_c, batch_size);
+ #endif
      }
 
  } // namespace detail
@@ -1830,31 +1850,10 @@ namespace dpct
              : id);
      }
 
-     template <typename T>
-     sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val)
-     {
-         return sycl::vec<T, 1>(val)
-             .template as<sycl::vec<
-                 std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>, 4>>()
-             .template convert<T>();
-     }
-
-     template <typename T1, typename T2>
-     using dot_product_acc_t =
-         std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
-                            uint32_t, int32_t>;
-
      template <typename T1, typename T2, typename T3>
      inline auto dp4a(T1 a, T2 b, T3 c)
      {
-         dot_product_acc_t<T1, T2> res = c;
-         auto va = extract_and_sign_or_zero_extend4(a);
-         auto vb = extract_and_sign_or_zero_extend4(b);
-         res += va[0] * vb[0];
-         res += va[1] * vb[1];
-         res += va[2] * vb[2];
-         res += va[3] * vb[3];
-         return res;
+         return syclcompat::dp4a(a, b, c);
      }
 
      struct sub_sat
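
The last hunk replaces dpct's hand-rolled dp4a with syclcompat::dp4a. The semantics are unchanged and are visible in the deleted reference code: each 32-bit operand is treated as four packed 8-bit lanes (sign-extended when the operand type is signed), the lanes are multiplied pairwise, and the four products are accumulated into c. A self-contained CPU rendering of that deleted reference, for the signed case (standalone illustration, not part of the package):

    #include <cstdint>
    #include <cstdio>

    // Mirrors the removed extract_and_sign_or_zero_extend4 + accumulate loop
    // for signed 32-bit inputs.
    static int32_t dp4a_ref(int32_t a, int32_t b, int32_t c) {
        int32_t res = c;
        for (int i = 0; i < 4; i++) {
            const int8_t va = (int8_t)(a >> (8 * i));  // extract lane i, sign-extend
            const int8_t vb = (int8_t)(b >> (8 * i));
            res += (int32_t)va * (int32_t)vb;
        }
        return res;
    }

    int main() {
        // lanes of a: {1, 2, 3, 4}; lanes of b: {5, 6, 7, 8}
        // expected: 1*5 + 2*6 + 3*7 + 4*8 = 70
        printf("%d\n", dp4a_ref(0x04030201, 0x08070605, 0));
        return 0;
    }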