llama_cpp 0.15.3 → 0.16.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +27 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +66 -36
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
  131. data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
  132. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
  133. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  134. data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
  135. data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
  136. data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
  137. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
  138. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
  139. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  140. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
  141. data/vendor/tmp/llama.cpp/ggml.c +301 -409
  142. data/vendor/tmp/llama.cpp/ggml.h +19 -23
  143. data/vendor/tmp/llama.cpp/llama.cpp +855 -651
  144. data/vendor/tmp/llama.cpp/llama.h +28 -48
  145. metadata +121 -6
  146. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  147. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  148. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  149. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
@@ -2944,6 +2944,57 @@ namespace dpct
2944
2944
  using shared_memory = detail::device_memory<T, shared, Dimension>;
2945
2945
 
2946
2946
 
2947
+ template <typename T,
2948
+ sycl::access::address_space addressSpace =
2949
+ sycl::access::address_space::global_space,
2950
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2951
+ sycl::memory_scope memoryScope = sycl::memory_scope::device>
2952
+ inline T atomic_fetch_add(T *addr, T operand) {
2953
+ auto atm =
2954
+ sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
2955
+ return atm.fetch_add(operand);
2956
+ }
2957
+
2958
+ template <sycl::access::address_space addressSpace =
2959
+ sycl::access::address_space::global_space,
2960
+ sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
2961
+ sycl::memory_scope memoryScope = sycl::memory_scope::device,
2962
+ typename T1, typename T2>
2963
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
2964
+ auto atm =
2965
+ sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
2966
+ return atm.fetch_add(operand);
2967
+ }
2968
+
2969
+ template <typename T, sycl::access::address_space addressSpace =
2970
+ sycl::access::address_space::global_space>
2971
+ inline T atomic_fetch_add(T *addr, T operand,
2972
+ sycl::memory_order memoryOrder) {
2973
+ switch (memoryOrder) {
2974
+ case sycl::memory_order::relaxed:
2975
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
2976
+ sycl::memory_scope::device>(addr, operand);
2977
+ case sycl::memory_order::acq_rel:
2978
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
2979
+ sycl::memory_scope::device>(addr, operand);
2980
+ case sycl::memory_order::seq_cst:
2981
+ return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
2982
+ sycl::memory_scope::device>(addr, operand);
2983
+ default:
2984
+ assert(false && "Invalid memory_order for atomics. Valid memory_order for "
2985
+ "atomics are: sycl::memory_order::relaxed, "
2986
+ "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
2987
+ }
2988
+ }
2989
+
2990
+ template <sycl::access::address_space addressSpace =
2991
+ sycl::access::address_space::global_space,
2992
+ typename T1, typename T2>
2993
+ inline T1 atomic_fetch_add(T1 *addr, T2 operand,
2994
+ sycl::memory_order memoryOrder) {
2995
+ atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
2996
+ }
2997
+
2947
2998
  } // COPY from DPCT head files
2948
2999
 
2949
3000
  #define GGML_COMMON_DECL_SYCL
@@ -2971,20 +3022,19 @@ static int g_work_group_size = 0;
2971
3022
  // typedef sycl::half ggml_fp16_t;
2972
3023
 
2973
3024
  #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
2974
- #define VER_4VEC 610 //todo for hardward optimize.
3025
+ #define VER_4VEC 130 //todo for hardward optimize.
2975
3026
  #define VER_GEN9 700 //todo for hardward optimize.
2976
3027
  #define VER_GEN12 1000000 //todo for hardward optimize.
2977
3028
  #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.
2978
3029
 
2979
3030
  #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares
2980
3031
 
2981
-
2982
- //define for XMX in Intel GPU
2983
- //TODO: currently, it's not used for XMX really.
2984
- #define SYCL_USE_XMX
3032
+ #if !defined(GGML_SYCL_FORCE_MMQ)
3033
+ #define SYCL_USE_XMX
3034
+ #endif
2985
3035
 
2986
3036
  // max batch size to use MMQ kernels when tensor cores are available
2987
- #define XMX_MAX_BATCH_SIZE 32
3037
+ #define MMQ_MAX_BATCH_SIZE 32
2988
3038
 
2989
3039
 
2990
3040
  #if defined(_MSC_VER)
@@ -3060,6 +3110,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
3060
3110
  bool ggml_backend_is_sycl(ggml_backend_t backend);
3061
3111
  int ggml_backend_sycl_get_device(ggml_backend_t backend);
3062
3112
  int get_main_device();
3113
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
3063
3114
  void print_ggml_tensor(const char*name, struct ggml_tensor *src);
3064
3115
  void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
3065
3116
 
@@ -8830,12 +8881,11 @@ static void rope(
8830
8881
  dst[i + 1] = x0*sin_theta + x1*cos_theta;
8831
8882
  }
8832
8883
 
8833
- template<typename T, bool has_pos>
8884
+ template<typename T, bool has_pos, bool has_freq_facs>
8834
8885
  static void rope_neox(
8835
8886
  const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
8836
- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
8837
- ,
8838
- const sycl::nd_item<3> &item_ct1) {
8887
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
8888
+ const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
8839
8889
  const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
8840
8890
  item_ct1.get_local_id(1));
8841
8891
 
@@ -8863,8 +8913,10 @@ static void rope_neox(
8863
8913
  float cur_rot = inv_ndims * ic - ib;
8864
8914
 
8865
8915
  const int p = has_pos ? pos[i2] : 0;
8916
+ const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
8917
+
8866
8918
  const float theta_base =
8867
- p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
8919
+ p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
8868
8920
 
8869
8921
  float cos_theta, sin_theta;
8870
8922
  rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -8876,49 +8928,6 @@ static void rope_neox(
8876
8928
  dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
8877
8929
  }
8878
8930
 
8879
- static void rope_glm_f32(
8880
- const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
8881
- int n_ctx
8882
- , const sycl::nd_item<3> &item_ct1) {
8883
- const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
8884
- item_ct1.get_local_id(2);
8885
- const int half_n_dims = ncols/4;
8886
-
8887
- if (col >= half_n_dims) {
8888
- return;
8889
- }
8890
-
8891
- const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
8892
- item_ct1.get_local_id(1);
8893
- const int i = row*ncols + col;
8894
- const int i2 = row/p_delta_rows;
8895
-
8896
- const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
8897
- // FIXME: this is likely wrong
8898
- const int p = pos != nullptr ? pos[i2] : 0;
8899
-
8900
- const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
8901
- const float sin_theta = sycl::sin((float)theta);
8902
- const float cos_theta = sycl::cos((float)theta);
8903
-
8904
- const float x0 = x[i + 0];
8905
- const float x1 = x[i + half_n_dims];
8906
-
8907
- dst[i + 0] = x0*cos_theta - x1*sin_theta;
8908
- dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
8909
-
8910
- const float block_theta =
8911
- ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
8912
- const float sin_block_theta = sycl::sin((float)block_theta);
8913
- const float cos_block_theta = sycl::cos((float)block_theta);
8914
-
8915
- const float x2 = x[i + half_n_dims * 2];
8916
- const float x3 = x[i + half_n_dims * 3];
8917
-
8918
- dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
8919
- dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
8920
- }
8921
-
8922
8931
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
8923
8932
  const sycl::nd_item<3> &item_ct1) {
8924
8933
  const int row = item_ct1.get_group(1);
@@ -12413,7 +12422,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
12413
12422
  const int32_t *pos, float freq_scale,
12414
12423
  int p_delta_rows, float freq_base, float ext_factor,
12415
12424
  float attn_factor, rope_corr_dims corr_dims,
12416
- dpct::queue_ptr stream) {
12425
+ const float * freq_factors, dpct::queue_ptr stream) {
12417
12426
  GGML_ASSERT(ncols % 2 == 0);
12418
12427
  const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
12419
12428
  const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12423,57 +12432,51 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
12423
12432
  const float inv_ndims = -1.0f / n_dims;
12424
12433
 
12425
12434
  if (pos == nullptr) {
12426
- /*
12427
- DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
12428
- the limit. To get the device limit, query
12429
- info::device::max_work_group_size. Adjust the work-group size if needed.
12430
- */
12431
12435
  dpct::has_capability_or_fail(stream->get_device(),
12432
12436
  {sycl::aspect::fp16});
12433
-
12434
- stream->parallel_for(
12435
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12436
- [=](sycl::nd_item<3> item_ct1) {
12437
- rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
12438
- p_delta_rows, ext_factor, attn_factor,
12439
- corr_dims, theta_scale, inv_ndims,
12440
- item_ct1);
12441
- });
12437
+ if (freq_factors == nullptr) {
12438
+ stream->parallel_for(
12439
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12440
+ [=](sycl::nd_item<3> item_ct1) {
12441
+ rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
12442
+ p_delta_rows, ext_factor, attn_factor,
12443
+ corr_dims, theta_scale, inv_ndims, freq_factors,
12444
+ item_ct1);
12445
+ });
12446
+ } else {
12447
+ stream->parallel_for(
12448
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12449
+ [=](sycl::nd_item<3> item_ct1) {
12450
+ rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
12451
+ p_delta_rows, ext_factor, attn_factor,
12452
+ corr_dims, theta_scale, inv_ndims, freq_factors,
12453
+ item_ct1);
12454
+ });
12455
+ }
12442
12456
  } else {
12443
- /*
12444
- DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
12445
- the limit. To get the device limit, query
12446
- info::device::max_work_group_size. Adjust the work-group size if needed.
12447
- */
12448
12457
  dpct::has_capability_or_fail(stream->get_device(),
12449
12458
  {sycl::aspect::fp16});
12450
12459
 
12451
- stream->parallel_for(
12452
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
12453
- [=](sycl::nd_item<3> item_ct1) {
12454
- rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
12455
- p_delta_rows, ext_factor, attn_factor,
12456
- corr_dims, theta_scale, inv_ndims, item_ct1);
12457
- });
12460
+ if (freq_factors == nullptr) {
12461
+ stream->parallel_for(
12462
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12463
+ [=](sycl::nd_item<3> item_ct1) {
12464
+ rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
12465
+ p_delta_rows, ext_factor, attn_factor,
12466
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12467
+ });
12468
+ } else {
12469
+ stream->parallel_for(
12470
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
12471
+ [=](sycl::nd_item<3> item_ct1) {
12472
+ rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
12473
+ p_delta_rows, ext_factor, attn_factor,
12474
+ corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
12475
+ });
12476
+ }
12458
12477
  }
12459
12478
  }
12460
12479
 
12461
- static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
12462
- const int32_t *pos, float freq_scale,
12463
- int p_delta_rows, float freq_base, int n_ctx,
12464
- dpct::queue_ptr stream) {
12465
- GGML_ASSERT(ncols % 4 == 0);
12466
- const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
12467
- const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
12468
- const sycl::range<3> block_nums(1, nrows, num_blocks_x);
12469
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
12470
- [=](sycl::nd_item<3> item_ct1) {
12471
- rope_glm_f32(x, dst, ncols, pos, freq_scale,
12472
- p_delta_rows, freq_base, n_ctx,
12473
- item_ct1);
12474
- });
12475
- }
12476
-
12477
12480
  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
12478
12481
  const int nrows, dpct::queue_ptr stream) {
12479
12482
  const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -13501,6 +13504,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
13501
13504
  const float *src0_dd, const float *src1_dd,
13502
13505
  float *dst_dd,
13503
13506
  const dpct::queue_ptr &main_stream) {
13507
+ #pragma message("TODO: generalize concat kernel for dim != 2")
13508
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
13509
+ int dim = dst->op_params[0];
13510
+ GGML_ASSERT(dim == 2);
13504
13511
 
13505
13512
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
13506
13513
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13986,9 +13993,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
13986
13993
  ggml_tensor *dst, const float *src0_dd,
13987
13994
  const float *src1_dd, float *dst_dd,
13988
13995
  const dpct::queue_ptr &main_stream) {
13989
- #pragma message("TODO: implement phi3 frequency factors support")
13990
- #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
13991
- GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
13996
+ const ggml_tensor * src2 = dst->src[2];
13992
13997
 
13993
13998
  GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
13994
13999
  GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -14002,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14002
14007
  //const int n_past = ((int32_t *) dst->op_params)[0];
14003
14008
  const int n_dims = ((int32_t *) dst->op_params)[1];
14004
14009
  const int mode = ((int32_t *) dst->op_params)[2];
14005
- const int n_ctx = ((int32_t *) dst->op_params)[3];
14006
- const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
14010
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
14011
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
14007
14012
 
14008
14013
  // RoPE alteration for extended context
14009
14014
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14014,6 +14019,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14014
14019
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
14015
14020
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
14016
14021
 
14022
+ const float * freq_factors = nullptr;
14017
14023
  const int32_t * pos = nullptr;
14018
14024
  if ((mode & 1) == 0) {
14019
14025
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14022,26 +14028,35 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
14022
14028
  }
14023
14029
 
14024
14030
  const bool is_neox = mode & 2;
14025
- const bool is_glm = mode & 4;
14031
+
14032
+ #pragma message("TODO: update rope NORM mode to match NEOX mode")
14033
+ #pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
14034
+
14035
+ if (is_neox) {
14036
+ pos = (const int32_t *) src1_dd;
14037
+
14038
+ if (src2 != nullptr) {
14039
+ freq_factors = (const float *) src2->data;
14040
+ }
14041
+ } else {
14042
+ GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
14043
+ }
14026
14044
 
14027
14045
  rope_corr_dims corr_dims;
14028
- ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
14046
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
14029
14047
 
14030
14048
  // compute
14031
- if (is_glm) {
14032
- GGML_ASSERT(false);
14033
- rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
14034
- } else if (is_neox) {
14049
+ if (is_neox) {
14035
14050
  if (src0->type == GGML_TYPE_F32) {
14036
14051
  rope_neox_sycl(
14037
14052
  (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
14038
- attn_factor, corr_dims, main_stream
14053
+ attn_factor, corr_dims, freq_factors, main_stream
14039
14054
  );
14040
14055
  } else if (src0->type == GGML_TYPE_F16) {
14041
14056
  rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
14042
14057
  ne00, n_dims, nrows, pos, freq_scale, ne01,
14043
14058
  freq_base, ext_factor, attn_factor, corr_dims,
14044
- main_stream);
14059
+ freq_factors, main_stream);
14045
14060
  } else {
14046
14061
  GGML_ASSERT(false);
14047
14062
  }
@@ -15108,7 +15123,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
15108
15123
  const int64_t r2 = ne12/ne02;
15109
15124
  const int64_t r3 = ne13/ne03;
15110
15125
 
15111
- if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
15126
+ if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
15112
15127
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
15113
15128
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
15114
15129
  *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
@@ -15173,6 +15188,29 @@ catch (sycl::exception const &exc) {
15173
15188
  std::exit(1);
15174
15189
  }
15175
15190
 
15191
+ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
15192
+ // TODO: accuracy issues in MMQ
15193
+ return false;
15194
+ }
15195
+
15196
+ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
15197
+ switch (type) {
15198
+ case GGML_TYPE_Q4_0:
15199
+ case GGML_TYPE_Q4_1:
15200
+ case GGML_TYPE_Q5_0:
15201
+ case GGML_TYPE_Q5_1:
15202
+ case GGML_TYPE_Q8_0:
15203
+ case GGML_TYPE_Q2_K:
15204
+ case GGML_TYPE_Q3_K:
15205
+ case GGML_TYPE_Q4_K:
15206
+ case GGML_TYPE_Q5_K:
15207
+ case GGML_TYPE_Q6_K:
15208
+ case GGML_TYPE_F16:
15209
+ return true;
15210
+ default:
15211
+ return false;
15212
+ }
15213
+ }
15176
15214
 
15177
15215
  static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
15178
15216
  const bool all_on_device =
@@ -15189,75 +15227,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
15189
15227
  }
15190
15228
  }
15191
15229
 
15192
- #ifdef SYCL_USE_XMX
15193
- const bool use_xmx = true;
15194
- #else
15195
- const bool use_xmx = false;
15196
- #endif
15230
+ // check data types and tensor shapes for custom matrix multiplication kernels:
15231
+ bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
15232
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
15233
+ && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
15234
+
15235
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
15236
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
15237
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
15197
15238
 
15198
- // debug helpers
15199
- //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
15200
- //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
15201
- //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
15202
- //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
15203
- //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
15204
- //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
15239
+ bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
15240
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
15205
15241
 
15206
- if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
15242
+ // mmvq and mmq need the __dp4a instruction which is available for gen12+
15243
+ // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
15244
+ use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
15245
+ #ifdef SYCL_USE_XMX
15246
+ use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
15247
+ #endif // SYCL_USE_XMX
15248
+
15249
+ if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
15207
15250
  // KQ single-batch
15208
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
15209
15251
  ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
15210
- } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
15252
+ } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
15211
15253
  // KQV single-batch
15212
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
15213
15254
  ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
15214
- } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
15255
+ } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
15215
15256
  // KQ + KQV multi-batch
15216
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
15217
15257
  ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
15218
- } else if (src0->type == GGML_TYPE_F32) {
15219
- // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
15220
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
15221
- } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
15222
- // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
15223
- if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
15224
- #ifdef GGML_SYCL_FORCE_DMMV
15225
- const bool use_mul_mat_vec_q = false;
15226
- #else
15227
- bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
15228
- use_mul_mat_vec_q = use_mul_mat_vec_q ||
15229
- (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
15230
- (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
15231
- (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
15232
- (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
15233
-
15234
-
15235
- #endif // GGML_SYCL_FORCE_DMMV
15236
-
15237
- if (use_mul_mat_vec_q) {
15238
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
15239
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15240
- } else {
15241
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
15242
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
15243
- }
15244
- } else {
15245
- bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
15246
-
15247
- if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
15248
- use_mul_mat_q = false;
15249
- }
15250
-
15251
- if (use_mul_mat_q) {
15252
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
15253
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
15254
- } else {
15255
- // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
15256
- ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
15257
- }
15258
- }
15258
+ } else if (use_dequantize_mul_mat_vec) {
15259
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
15260
+ } else if (use_mul_mat_vec_q) {
15261
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
15262
+ } else if (use_mul_mat_q) {
15263
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
15259
15264
  } else {
15260
- GGML_ASSERT(false);
15265
+ ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
15261
15266
  }
15262
15267
  }
15263
15268
 
@@ -15434,22 +15439,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
15434
15439
  }
15435
15440
  #endif
15436
15441
 
15442
+ struct mmid_row_mapping {
15443
+ int32_t i1;
15444
+ int32_t i2;
15445
+ };
15446
+
15447
+ __dpct_inline__ static void k_copy_src1_to_contiguous(
15448
+ const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
15449
+ int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
15450
+ const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
15451
+ int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
15452
+ const sycl::nd_item<3> &item_ct1, int &src1_row) {
15453
+ int32_t iid1 = item_ct1.get_group(2);
15454
+ int32_t id = item_ct1.get_group(1);
15455
+
15456
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
15457
+
15458
+ if (row_id_i != i02) {
15459
+ return;
15460
+ }
15461
+
15462
+ const int64_t i11 = id % ne11;
15463
+ const int64_t i12 = iid1;
15464
+
15465
+ if (item_ct1.get_local_id(2) == 0) {
15466
+ src1_row =
15467
+ dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
15468
+ cur_src1_row, 1);
15469
+ row_mapping[src1_row] = {id, iid1};
15470
+ }
15471
+ /*
15472
+ DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
15473
+ sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
15474
+ performance if there is no access to global memory.
15475
+ */
15476
+ item_ct1.barrier();
15477
+
15478
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
15479
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
15480
+
15481
+ #pragma unroll
15482
+ for (int i = item_ct1.get_local_id(2); i < ne10;
15483
+ i += item_ct1.get_local_range(2)) {
15484
+ src1_row_contiguous[i] = src1_row_original[i];
15485
+ }
15486
+ }
15487
+
15488
+ __dpct_inline__ static void k_copy_dst_from_contiguous(
15489
+ char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
15490
+ const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
15491
+ size_t nb2, const sycl::nd_item<3> &item_ct1) {
15492
+ int32_t i = item_ct1.get_group(2);
15493
+
15494
+ const int32_t i1 = row_mapping[i].i1;
15495
+ const int32_t i2 = row_mapping[i].i2;
15496
+
15497
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
15498
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
15499
+
15500
+ #pragma unroll
15501
+ for (int j = item_ct1.get_local_id(2); j < ne0;
15502
+ j += item_ct1.get_local_range(2)) {
15503
+ dst_row_original[j] = dst_row_contiguous[j];
15504
+ }
15505
+ }
15506
+
15437
15507
  static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15438
15508
  const ggml_tensor *src1,
15439
15509
  ggml_tensor *dst) try {
15440
- GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
15441
- "mul_mat_id does not support split buffers");
15510
+ GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
15511
+
15442
15512
  const ggml_tensor *ids = dst->src[2];
15443
- const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15513
+ GGML_TENSOR_BINARY_OP_LOCALS
15444
15514
 
15445
- const size_t nb11 = src1->nb[1];
15446
- const size_t nb1 = dst->nb[1];
15515
+ const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
15447
15516
 
15448
- const int32_t id = ((int32_t *)dst->op_params)[0];
15449
- const int32_t n_as = src0->ne[2];
15517
+ const int64_t n_as = ne02;
15518
+ const int64_t n_ids = ids->ne[0];
15450
15519
 
15451
15520
  std::vector<char> ids_host(ggml_nbytes(ids));
15452
- const char *ids_dev = (const char *)ids->data;
15521
+ const char * ids_dev = (const char *) ids->data;
15453
15522
 
15454
15523
  SYCL_CHECK(CHECK_TRY_ERROR(
15455
15524
  stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
@@ -15489,24 +15558,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15489
15558
 
15490
15559
  src0_row.ne[2] = 1;
15491
15560
  src0_row.ne[3] = 1;
15492
- src0_row.nb[3] = src0->nb[2];
15493
-
15494
- if (src1->ne[1] == 1) {
15495
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15496
- const int32_t row_id =
15497
- *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
15498
- id * ids->nb[0]);
15499
-
15500
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
15561
+ src0_row.nb[3] = nb02;
15562
+
15563
+ src1_row.ne[1] = 1;
15564
+ src1_row.ne[2] = 1;
15565
+ src1_row.ne[3] = 1;
15566
+ src1_row.nb[2] = nb11;
15567
+ src1_row.nb[3] = nb11;
15568
+
15569
+ dst_row.ne[1] = 1;
15570
+ dst_row.ne[2] = 1;
15571
+ dst_row.ne[3] = 1;
15572
+ dst_row.nb[2] = nb1;
15573
+ dst_row.nb[3] = nb1;
15574
+ if (ne12 == 1) {
15575
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15576
+ for (int64_t id = 0; id < n_ids; id++) {
15577
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
15578
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
15579
+
15580
+ const int64_t i11 = id % ne11;
15581
+ const int64_t i12 = iid1;
15582
+
15583
+ const int64_t i1 = id;
15584
+ const int64_t i2 = i12;
15501
15585
 
15502
15586
  src0_row_extra.data_device[g_main_device] =
15503
- src0_original + row_id * src0->nb[2];
15587
+ src0_original + i02*nb02;
15504
15588
  src1_row_extra.data_device[g_main_device] =
15505
- src1_original + i01 * src1->nb[1];
15589
+ src1_original + + i11*nb11 + i12*nb12;
15506
15590
  dst_row_extra.data_device[g_main_device] =
15507
- dst_original + i01 * dst->nb[1];
15591
+ dst_original + i1*nb1 + i2*nb2;
15508
15592
 
15509
15593
  ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
15594
+ }
15510
15595
  }
15511
15596
  } else {
15512
15597
  sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -15515,64 +15600,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
15515
15600
  src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
15516
15601
  dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
15517
15602
 
15518
- for (int32_t row_id = 0; row_id < n_as; ++row_id) {
15603
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
15519
15604
  int64_t num_src1_rows = 0;
15520
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15521
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
15605
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
15606
+ for (int64_t id = 0; id < n_ids; id++) {
15607
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
15522
15608
 
15523
- if (row_id_i != row_id) {
15524
- continue;
15525
- }
15609
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
15526
15610
 
15527
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
15611
+ if (row_id_i != i02) {
15612
+ continue;
15613
+ }
15528
15614
 
15529
- SYCL_CHECK(CHECK_TRY_ERROR(
15530
- stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
15531
- src1_original + i01 * nb11, nb11)));
15532
- num_src1_rows++;
15615
+ num_src1_rows++;
15616
+ }
15533
15617
  }
15534
15618
 
15535
15619
  if (num_src1_rows == 0) {
15536
15620
  continue;
15537
15621
  }
15538
15622
 
15539
- src0_row_extra.data_device[g_main_device] =
15540
- src0_original + row_id * src0->nb[2];
15541
15623
 
15624
+ sycl_pool_alloc<int> dev_cur_src1_row(1);
15625
+ sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
15626
+ SYCL_CHECK(CHECK_TRY_ERROR(
15627
+ stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
15628
+
15629
+ {
15630
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
15631
+ sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
15632
+ stream->submit([&](sycl::handler &cgh) {
15633
+ sycl::local_accessor<int, 0> src1_row_acc(cgh);
15634
+
15635
+ char *__restrict src1_contiguous_get =
15636
+ src1_contiguous.get();
15637
+ int *__restrict dev_cur_src1_row_get =
15638
+ dev_cur_src1_row.get();
15639
+ mmid_row_mapping *__restrict dev_row_mapping_get =
15640
+ dev_row_mapping.get();
15641
+ size_t ids_nb_ct6 = ids->nb[1];
15642
+ size_t ids_nb_ct7 = ids->nb[0];
15643
+
15644
+ cgh.parallel_for(
15645
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15646
+ [=](sycl::nd_item<3> item_ct1) {
15647
+ k_copy_src1_to_contiguous(
15648
+ src1_original, src1_contiguous_get,
15649
+ dev_cur_src1_row_get,
15650
+ dev_row_mapping_get, ids_dev, i02,
15651
+ ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
15652
+ item_ct1, src1_row_acc);
15653
+ });
15654
+ });
15655
+ }
15656
+
15657
+ src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
15658
+
15659
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
15660
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
15542
15661
  src1_row.ne[1] = num_src1_rows;
15543
- dst_row.ne[1] = num_src1_rows;
15544
15662
 
15545
15663
  src1_row.nb[1] = nb11;
15546
15664
  src1_row.nb[2] = num_src1_rows*nb11;
15547
15665
  src1_row.nb[3] = num_src1_rows*nb11;
15548
15666
 
15667
+ dst_row.ne[1] = num_src1_rows;
15549
15668
  dst_row.nb[1] = nb1;
15550
15669
  dst_row.nb[2] = num_src1_rows*nb1;
15551
15670
  dst_row.nb[3] = num_src1_rows*nb1;
15552
15671
 
15553
15672
  ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
15554
15673
 
15555
- num_src1_rows = 0;
15556
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
15557
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
15558
-
15559
- if (row_id_i != row_id) {
15560
- continue;
15561
- }
15562
-
15563
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
15564
-
15565
- SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
15566
- dst_original + i01 * nb1,
15567
- dst_contiguous.get() + num_src1_rows * nb1, nb1)));
15568
- num_src1_rows++;
15674
+ {
15675
+ sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
15676
+ sycl::range<3> grid_dims(1, 1, num_src1_rows);
15677
+ stream->submit([&](sycl::handler &cgh) {
15678
+ const char *__restrict dst_contiguous_get =
15679
+ dst_contiguous.get();
15680
+ const mmid_row_mapping *__restrict dev_row_mapping_get =
15681
+ dev_row_mapping.get();
15682
+
15683
+ cgh.parallel_for(
15684
+ sycl::nd_range<3>(grid_dims * block_dims, block_dims),
15685
+ [=](sycl::nd_item<3> item_ct1) {
15686
+ k_copy_dst_from_contiguous(dst_original,
15687
+ dst_contiguous_get,
15688
+ dev_row_mapping_get,
15689
+ ne0, nb1, nb2, item_ct1);
15690
+ });
15691
+ });
15569
15692
  }
15570
15693
  }
15571
15694
  }
15572
-
15573
- if (dst->backend == GGML_BACKEND_TYPE_CPU) {
15574
- SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
15575
- }
15576
15695
  }
15577
15696
  catch (sycl::exception const &exc) {
15578
15697
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -16555,10 +16674,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
16555
16674
  UNUSED(buffer);
16556
16675
  }
16557
16676
 
16558
- // unused at the moment
16559
- //static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
16560
- // return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
16561
- //}
16677
+ static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
16678
+ return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
16679
+ }
16562
16680
 
16563
16681
  GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
16564
16682
  ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;