@novastera-oss/llamarn 0.2.7 → 0.2.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. package/android/src/main/cpp/include/llama.h +8 -3
  2. package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
  3. package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
  4. package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
  5. package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
  6. package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
  7. package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
  8. package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
  9. package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
  10. package/cpp/LlamaCppModel.cpp +56 -22
  11. package/cpp/build-info.cpp +2 -2
  12. package/cpp/llama.cpp/CMakeLists.txt +1 -1
  13. package/cpp/llama.cpp/common/arg.cpp +7 -0
  14. package/cpp/llama.cpp/common/common.cpp +3 -0
  15. package/cpp/llama.cpp/common/common.h +1 -0
  16. package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
  17. package/cpp/llama.cpp/convert_hf_to_gguf.py +118 -20
  18. package/cpp/llama.cpp/ggml/CMakeLists.txt +1 -0
  19. package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
  20. package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
  21. package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -0
  22. package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
  23. package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +31 -2
  24. package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
  25. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
  26. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
  27. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
  28. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  29. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
  30. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
  31. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
  32. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
  33. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
  34. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
  35. package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
  36. package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
  37. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
  38. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +83 -102
  39. package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  40. package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
  41. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +192 -67
  42. package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
  43. package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
  44. package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
  45. package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +211 -33
  46. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
  47. package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
  48. package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +54 -29
  49. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  50. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  51. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  52. package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  53. package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  54. package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +84 -31
  55. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
  56. package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
  57. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
  58. package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
  59. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
  60. package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  61. package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
  62. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
  63. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +227 -41
  64. package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +362 -182
  65. package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
  66. package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +240 -535
  67. package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  68. package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
  69. package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
  70. package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
  71. package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
  72. package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
  73. package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  74. package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  75. package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
  76. package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
  77. package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +45 -54
  78. package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
  79. package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
  80. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
  81. package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
  82. package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
  83. package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
  84. package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
  85. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  86. package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  87. package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  88. package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
  89. package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  90. package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +57 -1
  91. package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  92. package/cpp/llama.cpp/ggml/src/ggml.c +69 -13
  93. package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
  94. package/cpp/llama.cpp/gguf-py/gguf/constants.py +76 -0
  95. package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +21 -0
  96. package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +64 -0
  97. package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
  98. package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
  99. package/cpp/llama.cpp/include/llama.h +8 -3
  100. package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
  101. package/cpp/llama.cpp/src/llama-arch.cpp +55 -0
  102. package/cpp/llama.cpp/src/llama-arch.h +18 -0
  103. package/cpp/llama.cpp/src/llama-batch.cpp +570 -359
  104. package/cpp/llama.cpp/src/llama-batch.h +98 -70
  105. package/cpp/llama.cpp/src/llama-chat.cpp +11 -6
  106. package/cpp/llama.cpp/src/llama-context.cpp +101 -107
  107. package/cpp/llama.cpp/src/llama-context.h +13 -13
  108. package/cpp/llama.cpp/src/llama-graph.cpp +199 -252
  109. package/cpp/llama.cpp/src/llama-graph.h +44 -32
  110. package/cpp/llama.cpp/src/llama-hparams.cpp +4 -0
  111. package/cpp/llama.cpp/src/llama-hparams.h +8 -0
  112. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +51 -53
  113. package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +19 -24
  114. package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +110 -104
  115. package/cpp/llama.cpp/src/llama-kv-cache-unified.h +17 -22
  116. package/cpp/llama.cpp/src/llama-kv-cells.h +35 -11
  117. package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +66 -67
  118. package/cpp/llama.cpp/src/llama-memory-hybrid.h +16 -21
  119. package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +69 -68
  120. package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
  121. package/cpp/llama.cpp/src/llama-memory.h +18 -22
  122. package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
  123. package/cpp/llama.cpp/src/llama-model.cpp +1006 -472
  124. package/cpp/llama.cpp/src/llama-model.h +22 -0
  125. package/cpp/llama.cpp/src/llama-quant.cpp +87 -5
  126. package/cpp/llama.cpp/src/llama-vocab.cpp +26 -3
  127. package/cpp/llama.cpp/src/llama-vocab.h +1 -0
  128. package/cpp/rn-utils.h +3 -0
  129. package/ios/include/common.h +1 -0
  130. package/ios/include/llama.h +8 -3
  131. package/ios/libs/llama.xcframework/Info.plist +19 -19
  132. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  133. package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  134. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  135. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
  136. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -3
  137. package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
  138. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  139. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  140. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  141. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  142. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  143. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  144. package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
  145. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  146. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  147. package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3744
  148. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
  149. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
  150. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -3
  151. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
  152. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
  153. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -3
  154. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
  155. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
  156. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
  157. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -3
  158. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
  159. package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
  160. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  161. package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4863
  162. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  163. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
  164. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -3
  165. package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
  166. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  167. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4834
  168. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3742
  169. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  170. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  171. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  172. package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
  173. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  174. package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4900
  175. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
  176. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
  177. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -3
  178. package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
  179. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
  180. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4871
  181. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3773
  182. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
  183. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
  184. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -3
  185. package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
  186. package/package.json +1 -1
@@ -413,7 +413,8 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
413
413
  {
414
414
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
415
415
 
416
- stream->parallel_for(
416
+ sycl_parallel_for(
417
+ stream,
417
418
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
418
419
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
419
420
  [=](sycl::nd_item<3> item_ct1) {
@@ -431,7 +432,8 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
431
432
  {
432
433
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
433
434
 
434
- stream->parallel_for(
435
+ sycl_parallel_for(
436
+ stream,
435
437
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
436
438
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
437
439
  [=](sycl::nd_item<3> item_ct1) {
@@ -449,7 +451,8 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
449
451
  {
450
452
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
451
453
 
452
- stream->parallel_for(
454
+ sycl_parallel_for(
455
+ stream,
453
456
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
454
457
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
455
458
  [=](sycl::nd_item<3> item_ct1) {
@@ -465,11 +468,11 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
465
468
  const int nb12, const int nb13, queue_ptr stream) {
466
469
  GGML_ASSERT(ne % QK8_0 == 0);
467
470
  const int num_blocks = ne / QK8_0;
468
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
469
- [=](sycl::nd_item<3> item_ct1) {
470
- cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
471
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
472
- });
471
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
472
+ [=](sycl::nd_item<3> item_ct1) {
473
+ cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
474
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
475
+ });
473
476
  }
474
477
 
475
478
  static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -477,11 +480,11 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
477
480
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
478
481
  const int nb12, const int nb13, queue_ptr stream) {
479
482
  const int num_blocks = ne;
480
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
481
- [=](sycl::nd_item<3> item_ct1) {
482
- cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
483
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
484
- });
483
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
484
+ [=](sycl::nd_item<3> item_ct1) {
485
+ cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
486
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
487
+ });
485
488
  }
486
489
 
487
490
  static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -490,11 +493,11 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
490
493
  const int nb12, const int nb13, queue_ptr stream) {
491
494
  GGML_ASSERT(ne % QK4_0 == 0);
492
495
  const int num_blocks = ne / QK4_0;
493
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
494
- [=](sycl::nd_item<3> item_ct1) {
495
- cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
496
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
497
- });
496
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
497
+ [=](sycl::nd_item<3> item_ct1) {
498
+ cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
499
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
500
+ });
498
501
  }
499
502
 
500
503
  static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -502,8 +505,9 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
502
505
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
503
506
  const int nb12, const int nb13, queue_ptr stream) {
504
507
  const int num_blocks = ne;
505
- stream->parallel_for(
506
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
508
+ sycl_parallel_for(
509
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
510
+ [=](sycl::nd_item<3> item_ct1) {
507
511
  cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
508
512
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
509
513
  item_ct1);
@@ -516,11 +520,11 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
516
520
  const int nb12, const int nb13, queue_ptr stream) {
517
521
  GGML_ASSERT(ne % QK4_1 == 0);
518
522
  const int num_blocks = ne / QK4_1;
519
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
520
- [=](sycl::nd_item<3> item_ct1) {
521
- cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
522
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
523
- });
523
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
524
+ [=](sycl::nd_item<3> item_ct1) {
525
+ cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
526
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
527
+ });
524
528
  }
525
529
 
526
530
  static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -528,8 +532,9 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
528
532
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
529
533
  const int nb12, const int nb13, queue_ptr stream) {
530
534
  const int num_blocks = ne;
531
- stream->parallel_for(
532
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
535
+ sycl_parallel_for(
536
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
537
+ [=](sycl::nd_item<3> item_ct1) {
533
538
  cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
534
539
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
535
540
  item_ct1);
@@ -542,11 +547,11 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
542
547
  const int nb12, const int nb13, queue_ptr stream) {
543
548
  GGML_ASSERT(ne % QK5_0 == 0);
544
549
  const int num_blocks = ne / QK5_0;
545
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
546
- [=](sycl::nd_item<3> item_ct1) {
547
- cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
548
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
549
- });
550
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
551
+ [=](sycl::nd_item<3> item_ct1) {
552
+ cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
553
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
554
+ });
550
555
  }
551
556
 
552
557
  static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -554,8 +559,9 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
554
559
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
555
560
  const int nb12, const int nb13, queue_ptr stream) {
556
561
  const int num_blocks = ne;
557
- stream->parallel_for(
558
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
562
+ sycl_parallel_for(
563
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
564
+ [=](sycl::nd_item<3> item_ct1) {
559
565
  cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
560
566
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
561
567
  item_ct1);
@@ -568,11 +574,11 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
568
574
  const int nb12, const int nb13, queue_ptr stream) {
569
575
  GGML_ASSERT(ne % QK5_1 == 0);
570
576
  const int num_blocks = ne / QK5_1;
571
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
572
- [=](sycl::nd_item<3> item_ct1) {
573
- cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
574
- ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
575
- });
577
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
578
+ [=](sycl::nd_item<3> item_ct1) {
579
+ cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
580
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
581
+ });
576
582
  }
577
583
 
578
584
  static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -580,8 +586,9 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
580
586
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
581
587
  const int nb12, const int nb13, queue_ptr stream) {
582
588
  const int num_blocks = ne;
583
- stream->parallel_for(
584
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
589
+ sycl_parallel_for(
590
+ stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
591
+ [=](sycl::nd_item<3> item_ct1) {
585
592
  cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
586
593
  nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
587
594
  item_ct1);
@@ -594,11 +601,11 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
594
601
  const int nb12, const int nb13, queue_ptr stream) {
595
602
  GGML_ASSERT(ne % QK4_NL == 0);
596
603
  const int num_blocks = ne / QK4_NL;
597
- stream->parallel_for(
598
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
599
- cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
600
- ne12, nb10, nb11, nb12, nb13, item_ct1);
601
- });
604
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
605
+ [=](sycl::nd_item<3> item_ct1) {
606
+ cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
607
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
608
+ });
602
609
  }
603
610
 
604
611
  static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -609,7 +616,8 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
609
616
  {
610
617
  dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
611
618
 
612
- stream->parallel_for(
619
+ sycl_parallel_for(
620
+ stream,
613
621
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
614
622
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
615
623
  [=](sycl::nd_item<3> item_ct1) {
@@ -628,7 +636,8 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
628
636
  // dpct::has_capability_or_fail(stream->get_device(),
629
637
  // {sycl::aspect::fp16});
630
638
 
631
- stream->parallel_for(
639
+ sycl_parallel_for(
640
+ stream,
632
641
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
633
642
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
634
643
  [=](sycl::nd_item<3> item_ct1) {
@@ -647,7 +656,8 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
647
656
  // dpct::has_capability_or_fail(stream->get_device(),
648
657
  // {sycl::aspect::fp16});
649
658
 
650
- stream->parallel_for(
659
+ sycl_parallel_for(
660
+ stream,
651
661
  sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
652
662
  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
653
663
  [=](sycl::nd_item<3> item_ct1) {
@@ -662,11 +672,13 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
662
672
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
663
673
  const int nb12, const int nb13, queue_ptr stream) {
664
674
  const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
665
- stream->parallel_for(
666
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
667
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
668
- cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
669
- });
675
+ sycl_parallel_for(stream,
676
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
677
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
678
+ [=](sycl::nd_item<3> item_ct1) {
679
+ cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
680
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
681
+ });
670
682
  }
671
683
 
672
684
 
@@ -675,11 +687,13 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
675
687
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
676
688
  const int nb12, const int nb13, queue_ptr stream) {
677
689
  const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
678
- stream->parallel_for(
679
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
680
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
681
- cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
682
- });
690
+ sycl_parallel_for(stream,
691
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
692
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
693
+ [=](sycl::nd_item<3> item_ct1) {
694
+ cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
695
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
696
+ });
683
697
  }
684
698
 
685
699
 
@@ -689,11 +703,13 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
689
703
  const int nb12, const int nb13, queue_ptr stream) {
690
704
  const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
691
705
 
692
- stream->parallel_for(
693
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
694
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
695
- cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
696
- });
706
+ sycl_parallel_for(stream,
707
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
708
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
709
+ [=](sycl::nd_item<3> item_ct1) {
710
+ cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
711
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
712
+ });
697
713
  }
698
714
 
699
715
 
@@ -702,10 +718,13 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
702
718
  const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
703
719
  const int nb12, const int nb13, queue_ptr stream) {
704
720
  const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
705
- stream->parallel_for(
706
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
707
- cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
708
- });
721
+ sycl_parallel_for(stream,
722
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
723
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
724
+ [=](sycl::nd_item<3> item_ct1) {
725
+ cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
726
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
727
+ });
709
728
  }
710
729
 
711
730
 
@@ -715,10 +734,13 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
715
734
  const int nb12, const int nb13, queue_ptr stream) {
716
735
 
717
736
  const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
718
- stream->parallel_for(
719
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
720
- cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
721
- });
737
+ sycl_parallel_for(stream,
738
+ sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
739
+ sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
740
+ [=](sycl::nd_item<3> item_ct1) {
741
+ cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
742
+ ne12, nb10, nb11, nb12, nb13, item_ct1);
743
+ });
722
744
  }
723
745
 
724
746
  void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
@@ -208,12 +208,10 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
208
208
  dpct::has_capability_or_fail(stream->get_device(),
209
209
  {sycl::aspect::fp16});
210
210
 
211
- stream->parallel_for(
212
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
213
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
214
- dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
215
- nrows, item_ct1);
216
- });
211
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
212
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
213
+ dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
214
+ });
217
215
  }
218
216
  }
219
217
 
@@ -877,12 +875,11 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
877
875
  dpct::has_capability_or_fail(stream->get_device(),
878
876
  {sycl::aspect::fp16});
879
877
 
880
- stream->parallel_for(
881
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
882
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
883
- dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
884
- vx, y, dst, ncols, nrows, item_ct1);
885
- });
878
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
879
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
880
+ dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
881
+ nrows, item_ct1);
882
+ });
886
883
  }
887
884
  }
888
885
 
@@ -900,12 +897,10 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
900
897
  dpct::has_capability_or_fail(stream->get_device(),
901
898
  {sycl::aspect::fp16});
902
899
 
903
- stream->parallel_for(
904
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
905
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
906
- dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
907
- vx, y, dst, ncols, nrows, item_ct1);
908
- });
900
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
901
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
902
+ dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
903
+ });
909
904
  }
910
905
  }
911
906
 
@@ -921,12 +916,10 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
921
916
  dpct::has_capability_or_fail(stream->get_device(),
922
917
  {sycl::aspect::fp16});
923
918
 
924
- stream->parallel_for(
925
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
926
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
927
- dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
928
- vx, y, dst, ncols, nrows, item_ct1);
929
- });
919
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
920
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
921
+ dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
922
+ });
930
923
  }
931
924
  }
932
925
 
@@ -942,12 +935,10 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
942
935
  dpct::has_capability_or_fail(stream->get_device(),
943
936
  {sycl::aspect::fp16});
944
937
 
945
- stream->parallel_for(
946
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
947
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
948
- dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
949
- vx, y, dst, ncols, nrows, item_ct1);
950
- });
938
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
939
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
940
+ dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
941
+ });
951
942
  }
952
943
  }
953
944
 
@@ -963,12 +954,10 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
963
954
  dpct::has_capability_or_fail(stream->get_device(),
964
955
  {sycl::aspect::fp16});
965
956
 
966
- stream->parallel_for(
967
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
968
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
969
- dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
970
- vx, y, dst, ncols, nrows, item_ct1);
971
- });
957
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
958
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
959
+ dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
960
+ });
972
961
  }
973
962
  }
974
963
 
@@ -984,12 +973,10 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
984
973
  dpct::has_capability_or_fail(stream->get_device(),
985
974
  {sycl::aspect::fp16});
986
975
 
987
- stream->parallel_for(
988
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
989
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
990
- dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
991
- vx, y, dst, ncols, nrows, item_ct1);
992
- });
976
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
977
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
978
+ dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
979
+ });
993
980
  }
994
981
  }
995
982
 
@@ -1002,11 +989,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
1002
989
  const int block_num_y = (nrows + ny - 1) / ny;
1003
990
  const sycl::range<3> block_nums(1, 1, block_num_y);
1004
991
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1005
- stream->parallel_for(
1006
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1007
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1008
- dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
1009
- });
992
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
993
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
994
+ dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
995
+ });
1010
996
  }
1011
997
 
1012
998
  static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
@@ -1018,11 +1004,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
1018
1004
  const int block_num_y = (nrows + ny - 1) / ny;
1019
1005
  const sycl::range<3> block_nums(1, 1, block_num_y);
1020
1006
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1021
- stream->parallel_for(
1022
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1023
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1024
- dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
1025
- });
1007
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
1008
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1009
+ dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
1010
+ });
1026
1011
  }
1027
1012
 
1028
1013
  static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
@@ -1034,11 +1019,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
1034
1019
  const int block_num_y = (nrows + ny - 1) / ny;
1035
1020
  const sycl::range<3> block_nums(1, 1, block_num_y);
1036
1021
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1037
- stream->parallel_for(
1038
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1039
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1040
- dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
1041
- });
1022
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
1023
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1024
+ dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
1025
+ });
1042
1026
  }
1043
1027
 
1044
1028
  static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
@@ -1047,11 +1031,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
1047
1031
  dpct::queue_ptr stream) {
1048
1032
  GGML_ASSERT(ncols % QK_K == 0);
1049
1033
  const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
1050
- stream->parallel_for(
1051
- sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
1052
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1053
- dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
1054
- });
1034
+ sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
1035
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1036
+ dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
1037
+ });
1055
1038
  }
1056
1039
 
1057
1040
  static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
@@ -1063,11 +1046,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
1063
1046
  const int block_num_y = (nrows + ny - 1) / ny;
1064
1047
  const sycl::range<3> block_nums(1, 1, block_num_y);
1065
1048
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1066
- stream->parallel_for(
1067
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1068
- [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1069
- dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
1070
- });
1049
+ sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
1050
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1051
+ dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
1052
+ });
1071
1053
  }
1072
1054
 
1073
1055
  void ggml_sycl_op_dequantize_mul_mat_vec(
@@ -13,10 +13,10 @@
13
13
  #ifndef GGML_SYCL_DPCT_HELPER_HPP
14
14
  #define GGML_SYCL_DPCT_HELPER_HPP
15
15
 
16
+ #include <map>
16
17
  #include <sycl/sycl.hpp>
17
18
  #include <sycl/half_type.hpp>
18
19
  #include <syclcompat/math.hpp>
19
- #include <map>
20
20
 
21
21
  #ifdef GGML_SYCL_USE_INTEL_ONEMKL
22
22
  #include <oneapi/mkl.hpp>
@@ -118,6 +118,36 @@ inline auto get_onemath_backend(sycl::queue& queue)
118
118
  #endif
119
119
  }
120
120
 
121
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
122
+ namespace syclex = sycl::ext::oneapi::experimental;
123
+ #endif
124
+
125
+ template <int NR, typename Func>
126
+ __dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
127
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
128
+ syclex::nd_launch(cgh, nd_range, func);
129
+ #else
130
+ cgh.parallel_for(nd_range, func);
131
+ #endif
132
+ }
133
+
134
+ template <int NR, typename Func>
135
+ __dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
136
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
137
+ syclex::nd_launch(*q, nd_range, func);
138
+ #else
139
+ q->parallel_for(nd_range, func);
140
+ #endif
141
+ }
142
+
143
+ template <typename Func> __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) {
144
+ #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
145
+ syclex::submit(*stream, func);
146
+ #else
147
+ stream->submit(func);
148
+ #endif
149
+ }
150
+
121
151
  namespace dpct
122
152
  {
123
153
  typedef sycl::queue *queue_ptr;