@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
The hunks below show the changes to entry 131 above, package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp (+210 -286).

@@ -37,6 +37,8 @@
  #include "ggml-backend-impl.h"
 
  #include "ggml-sycl/backend.hpp"
+ #include "ggml-sycl/common.hpp"
+ #include "ggml-sycl/element_wise.hpp"
  #include "ggml-sycl/presets.hpp"
  #include "ggml-sycl/gemm.hpp"
  #include "ggml-sycl/sycl_hw.hpp"
@@ -371,6 +373,8 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
  SYCL_CHECK(
  CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
+ // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU.
+ // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here.
  char* host_buf = (char*)malloc(size);
  memcpy(host_buf, data, size);
  SYCL_CHECK(
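Note: the comments added in this hunk describe a host-staging workaround — tensor data coming from the mmap()ed model file is first copied into a plain malloc'd buffer before the host-to-device copy. A minimal sketch of that pattern, assuming a plain SYCL 2020 queue (names here are illustrative, not the backend's):

```cpp
#include <sycl/sycl.hpp>
#include <cstdlib>
#include <cstring>

// Host-staging copy: read the mmap()ed region once into a malloc'd buffer,
// then hand that buffer to the SYCL runtime for the host-to-device copy, so
// the device copy never reads directly from memory-mapped file pages.
static void set_tensor_staged(sycl::queue & q, void * dev_dst,
                              const void * mmap_src, size_t size) {
    char * host_buf = (char *) malloc(size);
    memcpy(host_buf, mmap_src, size);
    q.memcpy(dev_dst, host_buf, size).wait();
    free(host_buf);
}
```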
@@ -490,6 +494,23 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
 
+ static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
+ size_t offset, size_t size) {
+ GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__);
+ ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+ SYCL_CHECK(ggml_sycl_set_device(ctx->device));
+ auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
+ if (size == 0) {
+ return; // Nothing to do
+ }
+ if (tensor->data == nullptr) {
+ GGML_ABORT("Error: Tensor data pointer is null.\n");
+ }
+ void * target_ptr = static_cast<char *>(tensor->data) + offset;
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
+ SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
+ }
+
  static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
  GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
  if (buffer == nullptr) {
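Note: this new function fills the `.memset_tensor` slot that was previously `NULL` (see the next hunk). A stripped-down sketch of the operation it performs, assuming a SYCL 2020 queue (the real code routes through dpct and SYCL_CHECK):

```cpp
#include <sycl/sycl.hpp>
#include <cstddef>
#include <cstdint>

// Byte-fill a sub-range of a device allocation and synchronize, mirroring
// the zero-size early-out in the hunk above.
static void memset_region(sycl::queue & q, void * base, size_t offset,
                          uint8_t value, size_t size) {
    if (size == 0) {
        return; // nothing to do
    }
    void * target = static_cast<char *>(base) + offset;
    q.memset(target, value, size).wait();
}
```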
@@ -510,7 +531,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
  /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
  /* .get_base = */ ggml_backend_sycl_buffer_get_base,
  /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
- /* .memset_tensor = */ NULL,
+ /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
  /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
  /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
  /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
@@ -1597,17 +1618,6 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
  dst[i] = scale * x[i];
  }
 
- static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
- const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= k) {
- return;
- }
-
- dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
- }
 
  template <typename Ti, typename To>
  static void pool2d_nchw_kernel(
@@ -1748,18 +1758,6 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
  });
  }
 
- static void clamp_f32_sycl(const float *x, float *dst, const float min,
- const float max, const int k,
- queue_ptr stream) {
- const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- clamp_f32(x, dst, min, max, k, item_ct1);
- });
- }
 
  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
  const int nrows, queue_ptr stream) {
@@ -1970,19 +1968,6 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
 
- static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_d, const float *src1_d,
- float *dst_d,
- const queue_ptr &main_stream) {
-
- ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(src1_d);
- }
-
-
  inline void ggml_sycl_op_mul_mat_sycl(
  ggml_backend_sycl_context & ctx,
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -2049,8 +2034,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
  const sycl::half alpha_f16 = 1.0f;
  const sycl::half beta_f16 = 0.0f;
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
- *stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+ *stream, oneapi::math::transpose::trans,
+ oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
  &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
  src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
  dst_f16.get(), dpct::library_data_t::real_half, ldc,
@@ -2058,9 +2043,9 @@
  const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
  #else
- auto dnnl_stream = ctx.stream_dnnl(stream);
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
- src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+ DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
+ DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+ dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
  const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
  #endif
@@ -2087,21 +2072,14 @@
  #if !GGML_SYCL_DNNL
  const float alpha = 1.0f;
  const float beta = 0.0f;
- # ifdef GGML_SYCL_NVIDIA
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
- ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
- # else
- SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
- *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
- dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
- dst_dd_i, ldc)));
- # endif
+ SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+ get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+ src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+ dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
  #else
- auto dnnl_stream = ctx.stream_dnnl(stream);
- DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
- src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+ DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
+ DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
  #endif
  }
  GGML_UNUSED(dst);
@@ -2114,13 +2092,14 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
 
- static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd, const queue_ptr &main_stream) {
+ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
  const int32_t * opts = (const int32_t *)dst->op_params;
  enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
@@ -2131,8 +2110,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
  const int p0 = opts[5];
  const int p1 = opts[6];
 
- const int64_t IH = src0->ne[1];
- const int64_t IW = src0->ne[0];
+ const int64_t IH = dst->src[0]->ne[1];
+ const int64_t IW = dst->src[0]->ne[0];
 
  const int64_t N = dst->ne[3];
  const int64_t OC = dst->ne[2];
@@ -2151,163 +2130,105 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
  parallel_elements, src0_dd, dst_dd, op,
  item_ct1);
  });
-
- GGML_UNUSED(src1);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
- inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
- const int64_t ne = ggml_nelements(src0);
+ const int64_t ne = ggml_nelements(dst->src[0]);
 
  sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
- inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);
 
  sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
- inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);
 
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);
 
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
- argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
+ argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
  }
 
- inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_dd, const float *src1_dd,
- float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
- const int64_t ncols = src0->ne[0];
- const int64_t nrows = ggml_nrows(src0);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ int32_t * dst_dd = static_cast<int32_t *>(dst->data);
 
- argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
+ const int64_t ncols = dst->src[0]->ne[0];
+ const int64_t nrows = ggml_nrows(dst->src[0]);
 
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
+ argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
  }
 
- inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) {
 
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
- const int64_t ne00 = src0->ne[0];
- const int64_t ne01 = src0->ne[1];
- const int nrows0 = ggml_nrows(src0);
+ const int64_t ne00 = dst->src[0]->ne[0];
+ const int64_t ne01 = dst->src[0]->ne[1];
+ const int nrows0 = ggml_nrows(dst->src[0]);
 
  const int n_past = ((int32_t *) dst->op_params)[0];
 
  diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
- inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
+ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ dpct::queue_ptr main_stream = ctx.stream();
+ SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+ const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
+ float * dst_dd = static_cast<float *>(dst->data);
 
  float scale;
  memcpy(&scale, dst->op_params, sizeof(float));
 
- scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
+ scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
  /*
  DPCT1010:87: SYCL uses exceptions to report errors and does not use the
  error codes. The call was replaced with 0. You need to rewrite this code.
  */
  SYCL_CHECK(0);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
- }
-
- inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- float min;
- float max;
- memcpy(&min, dst->op_params, sizeof(float));
- memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
-
- clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
- /*
- DPCT1010:88: SYCL uses exceptions to report errors and does not use the
- error codes. The call was replaced with 0. You need to rewrite this code.
- */
- SYCL_CHECK(0);
-
- GGML_UNUSED(src1);
- GGML_UNUSED(dst);
- GGML_UNUSED(src1_dd);
- GGML_UNUSED(ctx);
  }
 
  static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
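Note: the hunks above all follow one mechanical refactor. The old flatten-style ops received src0/src1, raw data pointers, and a queue from `ggml_sycl_op_flatten`; the new signatures take only the destination tensor and derive everything else locally. The recurring skeleton, distilled from the hunks (the op name and kernel launch are placeholders, the rest is the verbatim pattern):

```cpp
// New-style SYCL op: the destination tensor carries its inputs in dst->src[],
// and the queue comes from the backend context instead of a parameter.
inline void ggml_sycl_op_example(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    dpct::queue_ptr main_stream = ctx.stream();      // queue from the context
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
    float       * dst_dd  = static_cast<float *>(dst->data);
    // ... launch the kernel with src0_dd / dst_dd / main_stream ...
}
```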
@@ -2675,39 +2596,33 @@ catch (sycl::exception const &exc) {
  }
 
 
- static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat);
- GGML_SYCL_DEBUG("call %s done\n", __func__);
- }
-
  static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows);
+ ggml_sycl_op_get_rows(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }
 
  static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm);
+ ggml_sycl_op_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }
 
  static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm);
+ ggml_sycl_op_rms_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }
 
  static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
+ ggml_sycl_op_l2_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }
 
  static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_SYCL_DEBUG("call %s\n", __func__);
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
+ ggml_sycl_op_group_norm(ctx, dst);
  GGML_SYCL_DEBUG("call %s done\n", __func__);
  }
 
@@ -2863,14 +2778,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *main_stream, oneapi::mkl::transpose::trans,
- oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
- (const char *)src0_as_f16, dpct::library_data_t::real_half,
- nb01 / nb00, nb02 / nb00,
- (const char *)src1_f16, dpct::library_data_t::real_half,
- nb11 / nb10, nb12 / nb10, beta,
- (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
- ne12 * ne13, cu_compute_type)));
+ *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+ (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+ (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
+ cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
  } else {
  const int ne23 = ne12*ne13;
 
@@ -2905,7 +2816,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
  });
  }
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
- *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+ *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
  (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
  (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
  (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
@@ -2942,6 +2853,64 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
  }
  }
 
+ static void reorder_qw(char *data_device, const int ncols, const int nrows,
+ size_t size, size_t offset, dpct::queue_ptr stream) {
+ auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+ SYCL_CHECK(
+ CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
+ .wait()));
+ GGML_ASSERT((size % sizeof(block_q4_0) == 0));
+ GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
+ int offset_blks = offset / sizeof(block_q4_0);
+ auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
+ auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
+
+ stream->parallel_for(
+ size / sizeof(block_q4_0),
+ [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ const block_q4_0* x = (const block_q4_0*)tmp_buf;
+ const int ib = i;
+
+ for (int j = 0; j < QK4_0/2; j ++)
+ {
+ *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
+ }
+ *(d_ptr + ib) = x[ib].d;
+ });
+
+ sycl::free(tmp_buf, *stream);
+ }
+
+ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
+ char*data_device = (char*)src0->data;
+ size_t ncols = src0->ne[0];
+ size_t nrows = src0->ne[1];
+ size_t size = ggml_nbytes(src0);
+
+ reorder_qw(data_device, ncols, nrows, size, 0, stream);
+ }
+
+ /*
+ * This function could be called when the OP (mul_mat) function support reorder optimizition.
+ */
+ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+ ggml_tensor * dst) {
+ if (!g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT
+ ctx->opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf.
+ dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases.
+ src0->type == GGML_TYPE_Q4_0 &&
+ src1->ne[2]==1 && src1->ne[3]==1) {
+
+ ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
+ if (!extra) return; //only happen in CI/UT permute case.
+
+ if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder.
+
+ reorder_qw(src0, ctx->stream());
+ extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
+ }
+ }
+
  static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
  const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
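Note: `reorder_qw` rewrites a Q4_0 tensor in place from array-of-blocks to a split layout — all packed nibbles first, then all fp16 scales — so the reorder-aware mul_mat paths can read quants and scales as two contiguous streams. A host-side sketch of the same transform, assuming ggml's `block_q4_0` layout (one fp16 scale plus QK4_0/2 bytes of nibbles, per ggml-common.h):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

#define QK4_0 32                 // elements per Q4_0 block

typedef uint16_t ggml_half;      // fp16 bits, treated opaquely here

typedef struct {
    ggml_half d;                 // per-block fp16 scale
    uint8_t   qs[QK4_0 / 2];     // 32 4-bit quants packed into 16 bytes
} block_q4_0;

// AoS -> SoA: emit all packed nibbles first, then all scales, matching the
// layout the device kernel above produces in place.
static void reorder_q4_0_host(const block_q4_0 * src, uint8_t * qs_out,
                              ggml_half * d_out, size_t nblocks) {
    for (size_t ib = 0; ib < nblocks; ++ib) {
        memcpy(qs_out + ib * (QK4_0 / 2), src[ib].qs, QK4_0 / 2);
        d_out[ib] = src[ib].d;
    }
}
```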
@@ -3004,6 +2973,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  // KQ + KQV multi-batch
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
  } else if (use_dequantize_mul_mat_vec) {
+ opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
  // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
  } else if (use_mul_mat_vec_q) {
@@ -3011,6 +2981,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  } else if (use_mul_mat_q) {
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
  } else {
+ opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
  }
  }
@@ -3251,48 +3222,39 @@ catch (sycl::exception const &exc) {
  }
 
  static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale);
- }
-
- static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
+ ggml_sycl_op_scale(ctx, dst);
  }
 
  static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
- }
-
- static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
+ ggml_sycl_op_diag_mask_inf(ctx, dst);
  }
 
  static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d);
+ ggml_sycl_op_pool2d(ctx, dst);
  }
 
  static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col);
+ ggml_sycl_op_im2col(ctx, dst);
  }
 
  static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum);
+ ggml_sycl_op_sum(ctx, dst);
  }
 
  static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows);
+ ggml_sycl_op_sum_rows(ctx, dst);
  }
 
  static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort);
+ ggml_sycl_op_argsort(ctx, dst);
  }
 
  static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
- ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax);
+ ggml_sycl_op_argmax(ctx, dst);
  }
 
 
@@ -3317,7 +3279,7 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
 
- static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
+ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
  if (!g_sycl_loaded) return false;
 
  if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@@ -3394,6 +3356,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
  case GGML_UNARY_OP_EXP:
  ggml_sycl_exp(ctx, dst);
  break;
+ case GGML_UNARY_OP_SGN:
+ ggml_sycl_sgn(ctx, dst);
+ break;
+ case GGML_UNARY_OP_ABS:
+ ggml_sycl_abs(ctx, dst);
+ break;
+ case GGML_UNARY_OP_ELU:
+ ggml_sycl_elu(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -3510,6 +3481,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
  }
 
  return true;
+ } catch (sycl::exception & e) {
+ std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
  }
 
  GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
@@ -3641,71 +3615,8 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
 
- static void reorder_qw(char *data_device, const int ncols, const int nrows,
- size_t size, size_t offset, dpct::queue_ptr stream) {
- auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
- SYCL_CHECK(
- CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
- .wait()));
- GGML_ASSERT((size % sizeof(block_q4_0) == 0));
- GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
- int offset_blks = offset / sizeof(block_q4_0);
- auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
- auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
-
- stream->parallel_for(
- size / sizeof(block_q4_0),
- [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
- const block_q4_0* x = (const block_q4_0*)tmp_buf;
- const int ib = i;
-
- for (int j = 0; j < QK4_0/2; j ++)
- {
- *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
- }
- *(d_ptr + ib) = x[ib].d;
- });
-
- sycl::free(tmp_buf, *stream);
- }
-
- static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
- char*data_device = (char*)src0->data;
- size_t ncols = src0->ne[0];
- size_t nrows = src0->ne[1];
- size_t size = ggml_nbytes(src0);
-
- reorder_qw(data_device, ncols, nrows, size, 0, stream);
- }
-
- static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
- ggml_tensor *src0 = dst->src[0];
- ggml_tensor *src1 = dst->src[1];
-
- if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
- src1->ne[2]==1 && src1->ne[3]==1) {
- reorder_qw(src0, stream);
- ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
- GGML_ASSERT(extra);
- extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
- }
- }
-
- static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
- dpct::queue_ptr stream = ctx->stream();
- if (ctx->optimized_graph) {
- return;
- }
- ctx->optimized_graph = true;
-
- for (int i = 0; i < cgraph->n_nodes; i++) {
- if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
- }
- }
-
  static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
  ggml_sycl_set_main_device(sycl_ctx->device);
- if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
 
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_tensor * node = cgraph->nodes[i];
@@ -3733,7 +3644,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
 
  #ifdef GGML_SYCL_GRAPH
  if (!g_ggml_sycl_disable_graph) {
- if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
+ const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph);
+ if (!graph_support) {
  GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
  ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
  return GGML_STATUS_SUCCESS;
@@ -3744,8 +3656,10 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
  ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
  model_sycl_graph.end_recording();
 
- if (!sycl_ctx->exec_graph) {
- auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
+ const bool graph_update_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph);
+ if (!sycl_ctx->exec_graph || !graph_update_support) {
+ auto exec_graph = graph_update_support ? model_sycl_graph.finalize(sycl_ex::property::graph::updatable{}) :
+ model_sycl_graph.finalize();
  sycl_ctx->exec_graph = std::make_unique<
  sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
  } else {
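Note: these two graph hunks split SYCL-Graph capability detection in two. `ext_oneapi_limited_graph` gates whether a graph can be recorded and replayed at all, while `ext_oneapi_graph` additionally gates in-place update of a finalized executable graph, so devices without update support re-finalize on every submit instead of being excluded outright. A sketch of the resulting decision ladder (aspect names as used in the diff, from the sycl_ext_oneapi_graph extension):

```cpp
#include <sycl/sycl.hpp>

// 0 = eager execution only, 1 = record/replay (re-finalize each submit),
// 2 = record once, then update the finalized executable graph in place.
static int sycl_graph_level(const sycl::device & dev) {
    if (!dev.has(sycl::aspect::ext_oneapi_limited_graph)) {
        return 0;   // no graph support: fall back to the eager path
    }
    if (!dev.has(sycl::aspect::ext_oneapi_graph)) {
        return 1;   // graphs work, but executable graphs are immutable
    }
    return 2;       // full support: finalize with the updatable property
}
```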
@@ -3933,7 +3847,14 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_TANH:
  case GGML_UNARY_OP_EXP:
- return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
+ case GGML_UNARY_OP_SGN:
+ case GGML_UNARY_OP_ABS:
+ case GGML_UNARY_OP_ELU:
+ #if defined (GGML_SYCL_F16)
+ return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
+ #else
+ return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+ #endif
  default:
  return false;
  }
@@ -4045,7 +3966,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_ARGMAX:
  case GGML_OP_NONE:
  case GGML_OP_RESHAPE:
- case GGML_OP_REPEAT:
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
  case GGML_OP_TRANSPOSE:
@@ -4055,13 +3975,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
+ case GGML_OP_REPEAT:
+ return true;
  case GGML_OP_SQR:
  case GGML_OP_SQRT:
  case GGML_OP_SIN:
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
  case GGML_OP_LOG:
- return (op->src[0]->type == GGML_TYPE_F32);
+ #if defined (GGML_SYCL_F16)
+ return ((op->type == GGML_TYPE_F32 || op->type == GGML_SYCL_F16) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_SYCL_F16) && (op->type == op->src[0]->type));
+ #else
+ return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
+ #endif
  case GGML_OP_NORM:
  case GGML_OP_RMS_NORM:
  case GGML_OP_L2_NORM:
@@ -4077,23 +4003,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_ROPE:
  {
  const int mode = ((const int32_t *) op->op_params)[2];
- if (mode & GGML_ROPE_TYPE_MROPE) {
+ // mode is not used as a bitmask in practice, the various rope type modes are independent implementations
+ if (mode == GGML_ROPE_TYPE_MROPE) {
  return false;
  }
- if (mode & GGML_ROPE_TYPE_VISION) {
- return false;
- }
- return ggml_is_contiguous(op->src[0]);
+ return true;
  }
  case GGML_OP_IM2COL:
- // TODO: add support for the new F32 operations
- return op->src[0]->type == GGML_TYPE_F16;
+ return true;
+ case GGML_OP_UPSCALE:
+ return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
  case GGML_OP_POOL_2D:
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
  case GGML_OP_ARGSORT:
  case GGML_OP_ACC:
- case GGML_OP_UPSCALE:
  case GGML_OP_PAD:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_TIMESTEP_EMBEDDING:
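Note: the GGML_OP_ROPE hunk replaces bitmask tests with an equality test. That matters because the rope type constants overlap as bit patterns: to my reading of ggml.h at this release, GGML_ROPE_TYPE_MROPE is 8 and GGML_ROPE_TYPE_VISION is 24 (MROPE plus an extra bit), so `mode & GGML_ROPE_TYPE_MROPE` also fired for vision mode, while `mode == GGML_ROPE_TYPE_MROPE` rejects only plain M-RoPE. A small illustration of the difference (constant values are stated assumptions, not taken from this diff):

```cpp
#include <cassert>

// Assumed values, matching ggml.h at the time of this release.
#define GGML_ROPE_TYPE_MROPE  8
#define GGML_ROPE_TYPE_VISION 24   // == GGML_ROPE_TYPE_MROPE | 16

int main() {
    const int mode = GGML_ROPE_TYPE_VISION;
    // Old bitmask check: vision also sets the MROPE bit, so it was rejected.
    assert((mode & GGML_ROPE_TYPE_MROPE) != 0);
    // New equality check: only plain M-RoPE is rejected; vision now passes.
    assert(mode != GGML_ROPE_TYPE_MROPE);
    return 0;
}
```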