@fugood/llama.node 0.3.16 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +238 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +6 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  130. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
  131. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  133. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  135. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  136. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
  142. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  143. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  144. package/src/llama.cpp/include/llama.h +30 -11
  145. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  147. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  149. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  150. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  151. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  152. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  153. package/src/llama.cpp/src/llama-arch.cpp +160 -17
  154. package/src/llama.cpp/src/llama-arch.h +16 -0
  155. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  156. package/src/llama.cpp/src/llama-chat.h +6 -2
  157. package/src/llama.cpp/src/llama-context.cpp +108 -92
  158. package/src/llama.cpp/src/llama-context.h +1 -2
  159. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  160. package/src/llama.cpp/src/llama-graph.h +26 -6
  161. package/src/llama.cpp/src/llama-hparams.h +13 -0
  162. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  163. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  164. package/src/llama.cpp/src/llama-memory.h +1 -1
  165. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  166. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  167. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  168. package/src/llama.cpp/src/llama-model.cpp +1760 -534
  169. package/src/llama.cpp/src/llama-model.h +13 -1
  170. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  171. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  172. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  173. package/src/llama.cpp/src/llama.cpp +1 -1
  174. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  175. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  176. package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
  177. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  178. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  179. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  180. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  181. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  182. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  183. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  184. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  185. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  186. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  188. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  189. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  190. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  191. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  192. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  193. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp
@@ -313,7 +313,6 @@ struct ggml_backend_sycl_context {
     int device;
     std::string name;
     optimize_feature opt_feature;
-    bool optimized_graph=false;
 
     queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } };
 
@@ -494,298 +493,9 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
 
 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
 
-typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst, const float *src0_dd,
-                                       const float *src1_dd, float *dst_dd,
-                                       const queue_ptr &main_stream);
-
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1, int s2, int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1));
-    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) /
-                   ne3;
-    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) %
-                   ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    for (int i0 = i0s; i0 < ne0;
-         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-        const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-    }
-}
-
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1, int s2, int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
-
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-}
-
-
-template<float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
-    template <typename src0_t, typename src1_t, typename dst_t>
-    void operator()(ggml_backend_sycl_context & ctx,
-                    const struct ggml_tensor *src0,
-                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
-                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
-                    queue_ptr stream) {
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne[] = {ne0, ne1, ne2, ne3};
-        int64_t cne0[] = {ne00, ne01, ne02, ne03};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb[] = {nb0, nb1, nb2, nb3};
-        size_t cnb0[] = {nb00, nb01, nb02, nb03};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            for (int i = 0; i < 4; i++) {
-                if (nr[i] != 1) {
-                    break;
-                }
-                if (i > 0) {
-                    collapse_nb(cnb, cne);
-                    collapse_nb(cnb0, cne0);
-                    collapse_nb(cnb1, cne1);
-                    collapse(cne);
-                    collapse(cne0);
-                    collapse(cne1);
-                }
-            }
-        }
-        {
-            int64_t ne0 = cne[0];
-            int64_t ne1 = cne[1];
-            int64_t ne2 = cne[2];
-            int64_t ne3 = cne[3];
-
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
-
-            size_t nb0 = cnb[0];
-            size_t nb1 = cnb[1];
-            size_t nb2 = cnb[2];
-            size_t nb3 = cnb[3];
-
-            size_t nb00 = cnb0[0];
-            size_t nb01 = cnb0[1];
-            size_t nb02 = cnb0[2];
-            size_t nb03 = cnb0[3];
-
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
-
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
-
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
-
-            size_t s00 = nb00 / sizeof(src0_t);
-            size_t s01 = nb01 / sizeof(src0_t);
-            size_t s02 = nb02 / sizeof(src0_t);
-            size_t s03 = nb03 / sizeof(src0_t);
-
-            GGML_UNUSED(s00);
-
-            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
-            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
-            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
-
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
-
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
-                                s03, s11, s12, s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
-
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s01, s02, s03, s11, s12, s13,
-                                            item_ct1);
-                    });
-            }
-        }
-        GGML_UNUSED(ctx);
-    }
-};
-
-template <class op>
-inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd,
-                                   const queue_ptr &main_stream) {
-
-    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
-             (sycl::half *)dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
-        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
-        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
-             main_stream);
-    } else {
-        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
-                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
+constexpr size_t ceil_div(const size_t m, const size_t n) {
+    return (m + n - 1) / n;
 }
 
 bool gpu_has_xmx(sycl::device &dev);
-
-void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                          const ggml_tensor *src1, ggml_tensor *dst,
-                          const ggml_sycl_op_flatten_t op);
 #endif // GGML_SYCL_COMMON_HPP
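
Note: judging by the file list above (entries 121-122), the flatten/bin-bcast helpers removed here were not dropped but moved into the new ggml-sycl/binbcast.cpp and binbcast.hpp. The only addition to common.hpp is the ceil_div helper, a round-up integer division. A minimal standalone sketch of its behavior (the asserts and main are hypothetical test scaffolding, not part of the package):

    #include <cstddef>
    using std::size_t;  // match the unqualified spelling used in common.hpp

    // The helper as added to common.hpp: integer division rounded up.
    constexpr size_t ceil_div(const size_t m, const size_t n) {
        return (m + n - 1) / n;
    }

    // Typical use is sizing kernel launches: covering 1000 work items with
    // work-groups of 256 takes 4 groups, the last one partially filled.
    static_assert(ceil_div(1000, 256) == 4, "partial group rounds up");
    static_assert(ceil_div(1024, 256) == 4, "exact multiples unchanged");

    int main() { return 0; }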
package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -16,9 +16,18 @@
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
 #include <syclcompat/math.hpp>
-#include <oneapi/mkl.hpp>
 #include <map>
 
+#ifdef GGML_SYCL_USE_INTEL_ONEMKL
+#include <oneapi/mkl.hpp>
+// Allow to use the same namespace for Intel oneMKL and oneMath
+namespace oneapi {
+namespace math = mkl;
+}
+#else
+#include <oneapi/math.hpp>
+#endif
+
 #include "ggml.h"
 
 #if defined(__linux__)
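
The new include block keeps an Intel oneMKL build path behind GGML_SYCL_USE_INTEL_ONEMKL while the other backends appear to move to the open oneMath headers; the namespace alias lets every downstream use be spelled oneapi::math either way. A self-contained sketch of that alias trick, with hypothetical stand-in namespaces:

    #include <type_traits>

    namespace vendor_mkl { struct transpose { int v; }; }  // stand-in for oneapi::mkl

    // Mirrors `namespace oneapi { namespace math = mkl; }` from the hunk above.
    namespace compat { namespace math = vendor_mkl; }

    // Both spellings now name the same type, so code written against
    // compat::math compiles unchanged on top of vendor_mkl.
    static_assert(std::is_same_v<compat::math::transpose, vendor_mkl::transpose>);

    int main() { return 0; }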
@@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
 }
 
 template <typename Ts> struct matrix_info_t {
-    oneapi::mkl::transpose transpose_info[2];
+    oneapi::math::transpose transpose_info[2];
     Ts value_info[2];
     std::int64_t size_info[3];
     std::int64_t ld_info[3];
     std::int64_t groupsize_info;
 };
 
+inline auto get_onemath_backend(sycl::queue& queue)
+#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
+    -> sycl::queue&
+#endif
+{
+    // If the backend is known at compile-time, use oneMath backend_selector to use
+    // compile-time dispatching and avoid the need to dlopen libraries. Otherwise
+    // fallback to runtime dispatching.
+#if defined(GGML_SYCL_NVIDIA)
+    return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
+#elif defined(GGML_SYCL_AMD)
+    return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
+#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
+    return queue;
+#else
+    static_assert(false, "Unsupported backend");
+#endif
+}
+
 namespace dpct
 {
     typedef sycl::queue *queue_ptr;
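
get_onemath_backend() returns a oneapi::math::backend_selector by value when the backend is fixed at compile time (cuBLAS for GGML_SYCL_NVIDIA, rocBLAS for GGML_SYCL_AMD, avoiding any dlopen) and a plain sycl::queue& otherwise, which is why the trailing return type is spelled out only in the generic/oneMKL branches; auto deduction covers both. A hedged sketch of a caller, assuming this header (and thus oneMath or oneMKL) is already included; the wrapper name and GEMM parameters are hypothetical, while the call shape matches the dpct helpers further down:

    static void sgemm_f32(sycl::queue & q, int m, int n, int k,
                          const float * a, const float * b, float * c) {
        // The oneMath BLAS entry points accept either a backend_selector or a
        // queue in the first position, so this compiles for every branch.
        oneapi::math::blas::column_major::gemm(
            get_onemath_backend(q), oneapi::math::transpose::nontrans,
            oneapi::math::transpose::nontrans, m, n, k,
            /*alpha=*/1.0f, a, /*lda=*/m, b, /*ldb=*/k, /*beta=*/0.0f, c, /*ldc=*/m);
        q.wait();  // the USM API is asynchronous
    }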
@@ -1686,26 +1714,18 @@ namespace dpct
 
    namespace detail
    {
-        template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                              oneapi::mkl::transpose b_trans, int m, int n, int k,
-                              const void *alpha, const void *a, int lda, const void *b,
-                              int ldb, const void *beta, void *c, int ldc)
-        {
-            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
-            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
-            auto data_a = get_memory<const Ta>(a);
-            auto data_b = get_memory<const Tb>(b);
-            auto data_c = get_memory<Tc>(c);
-#ifdef GGML_SYCL_NVIDIA
-            oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
-                                                  a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
-                                                  beta_value, data_c, ldc);
-#else
-            oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
-                                                  beta_value, data_c, ldc);
-#endif
-        }
+        template <class Ta, class Tb, class Tc, class Ts>
+        inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+                              int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
+                              const void * beta, void * c, int ldc) {
+            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
+            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
+            auto data_a = get_memory<const Ta>(a);
+            auto data_b = get_memory<const Tb>(b);
+            auto data_c = get_memory<Tc>(c);
+            oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
+                                                   lda, data_b, ldb, beta_value, data_c, ldc);
+        }
 
        template <typename VecT, class BinaryOperation, class = void>
        class vectorized_binary
@@ -1735,7 +1755,7 @@ namespace dpct
        };
 
        template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
                                    int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
                                    int ldb, const void * beta, void ** c, int ldc, int batch_size,
                                    matrix_info_t<float> * matrix_info) {
@@ -1754,48 +1774,28 @@ namespace dpct
            matrix_info->ld_info[2] = ldc;
            matrix_info->groupsize_info = batch_size;
 
-#ifdef GGML_SYCL_NVIDIA
-            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
-                matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
-                matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
-                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
-#else
-            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
-                matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
-                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
-#endif
+            sycl::event e = oneapi::math::blas::column_major::gemm_batch(
+                get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
+                matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
+                reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
+                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+                reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
+                matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
        }
 
        template <class Ta, class Tb, class Tc, class Ts>
-        inline void
-        gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                        oneapi::mkl::transpose b_trans, int m, int n,
-                        int k, const void *alpha, const void *a, int lda,
-                        long long int stride_a, const void *b, int ldb,
-                        long long int stride_b, const void *beta, void *c,
-                        int ldc, long long int stride_c, int batch_size)
-        {
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
+                                    int m, int n, int k, const void * alpha, const void * a, int lda,
+                                    long long int stride_a, const void * b, int ldb, long long int stride_b,
+                                    const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
            auto data_a = get_memory<const Ta>(a);
            auto data_b = get_memory<const Tb>(b);
            auto data_c = get_memory<Tc>(c);
-#ifdef GGML_SYCL_NVIDIA
-            oneapi::mkl::blas::column_major::gemm_batch(
-                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
-                alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
-                batch_size);
-#else
-            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-                                                        stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
-                                                        stride_c, batch_size);
-#endif
+            oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
+                                                         data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
+                                                         data_c, ldc, stride_c, batch_size);
        }
 
    } // namespace detail
@@ -2259,13 +2259,10 @@ namespace dpct
                           sycl::range<3>(x, y, 1), direction);
    }
 
-    inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                     oneapi::mkl::transpose b_trans, int m, int n, int k,
-                     const void *alpha, const void *a, library_data_t a_type,
-                     int lda, const void *b, library_data_t b_type, int ldb,
-                     const void *beta, void *c, library_data_t c_type, int ldc,
-                     library_data_t scaling_type)
-    {
+    inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
+                     int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
+                     library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
+                     library_data_t scaling_type) {
        if (scaling_type == library_data_t::real_float &&
            c_type == library_data_t::complex_float)
        {
@@ -2329,9 +2326,8 @@ namespace dpct
            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
            library_data_t::real_float, library_data_t::real_float):
        {
-            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
-                              float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
-                                     ldb, beta, c, ldc);
+            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;
        }
        case detail::get_type_combination_id(
@@ -2369,8 +2365,7 @@ namespace dpct
            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
            library_data_t::real_bfloat16, library_data_t::real_float):
        {
-            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                              oneapi::mkl::bfloat16, float>(
+            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
            break;
        }
@@ -2390,7 +2385,7 @@ namespace dpct
        default:
            throw std::runtime_error("the combination of data type is unsupported");
        }
-    } // gemm()
+    }  // gemm()
 
    /// Computes a batch of matrix-matrix product with general matrices.
    /// \param [in] q The queue where the routine should be executed.
@@ -2412,7 +2407,7 @@ namespace dpct
    /// \param [in] ldc Leading dimension of C.
    /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
    /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
+    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
                            int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
                            const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
                            library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2450,7 +2445,7 @@ namespace dpct
            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
            library_data_t::real_bfloat16, library_data_t::real_float):
        {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
            break;
        }
@@ -2458,7 +2453,7 @@ namespace dpct
            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
            library_data_t::real_float, library_data_t::real_float):
        {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
            break;
        }
@@ -2534,15 +2529,11 @@ namespace dpct
    /// \param [in] stride_c Stride between the different C matrices.
    /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
    /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                           oneapi::mkl::transpose b_trans, int m, int n, int k,
-                           const void *alpha, const void *a, library_data_t a_type,
-                           int lda, long long int stride_a, const void *b,
-                           library_data_t b_type, int ldb, long long int stride_b,
-                           const void *beta, void *c, library_data_t c_type,
-                           int ldc, long long int stride_c, int batch_size,
-                           library_data_t scaling_type)
-    {
+    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+                           int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
+                           long long int stride_a, const void * b, library_data_t b_type, int ldb,
+                           long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
+                           long long int stride_c, int batch_size, library_data_t scaling_type) {
        if (scaling_type == library_data_t::real_float &&
            c_type == library_data_t::complex_float)
        {
@@ -2611,20 +2602,18 @@ namespace dpct
            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
            library_data_t::real_bfloat16, library_data_t::real_float):
        {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                                    oneapi::mkl::bfloat16, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
-                beta, c, ldc, stride_c, batch_size);
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
+                batch_size);
            break;
        }
        case detail::get_type_combination_id(
            library_data_t::real_bfloat16, library_data_t::real_bfloat16,
            library_data_t::real_float, library_data_t::real_float):
        {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
-                                    float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
-                                           stride_a, b, ldb, stride_b, beta, c, ldc,
-                                           stride_c, batch_size);
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
+                batch_size);
            break;
        }
 #endif
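
All of the helper.hpp hunks above are the same mechanical retarget: oneapi::mkl becomes oneapi::math, and the per-backend #ifdef GGML_SYCL_NVIDIA blocks collapse into the single get_onemath_backend(q) dispatch. For the bf16-in/f32-out strided batch case handled last, a hedged usage sketch (buffer allocation and dimensions are hypothetical; the argument order follows the new dpct::gemm_batch signature):

    // C[i] = A[i] * B[i] for `batch` column-major bf16 matrices with f32 output.
    static void batched_bf16_gemm(sycl::queue & q, const void * a, const void * b,
                                  void * c, int m, int n, int k, int batch) {
        const float alpha = 1.0f;  // scaling_type is real_float, so alpha/beta are floats
        const float beta  = 0.0f;
        dpct::gemm_batch(q, oneapi::math::transpose::nontrans,
                         oneapi::math::transpose::nontrans, m, n, k, &alpha,
                         a, dpct::library_data_t::real_bfloat16, /*lda=*/m,
                         /*stride_a=*/(long long int) m * k,
                         b, dpct::library_data_t::real_bfloat16, /*ldb=*/k,
                         /*stride_b=*/(long long int) k * n,
                         &beta, c, dpct::library_data_t::real_float, /*ldc=*/m,
                         /*stride_c=*/(long long int) m * n,
                         batch, dpct::library_data_t::real_float);
    }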