llama_cpp 0.15.3 → 0.16.0

Files changed (149)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/ext/llama_cpp/extconf.rb +1 -2
  4. data/ext/llama_cpp/llama_cpp.cpp +27 -3
  5. data/lib/llama_cpp/version.rb +2 -2
  6. data/sig/llama_cpp.rbs +15 -1
  7. data/vendor/tmp/llama.cpp/Makefile +66 -36
  8. data/vendor/tmp/llama.cpp/ggml-alloc.c +4 -4
  9. data/vendor/tmp/llama.cpp/ggml-backend.c +5 -5
  10. data/vendor/tmp/llama.cpp/ggml-backend.h +1 -1
  11. data/vendor/tmp/llama.cpp/ggml-cuda/acc.cu +47 -0
  12. data/vendor/tmp/llama.cpp/ggml-cuda/arange.cu +34 -0
  13. data/vendor/tmp/llama.cpp/ggml-cuda/argsort.cu +103 -0
  14. data/vendor/tmp/llama.cpp/ggml-cuda/binbcast.cu +280 -0
  15. data/vendor/tmp/llama.cpp/ggml-cuda/clamp.cu +34 -0
  16. data/vendor/tmp/llama.cpp/ggml-cuda/concat.cu +196 -0
  17. data/vendor/tmp/llama.cpp/ggml-cuda/convert.cu +686 -0
  18. data/vendor/tmp/llama.cpp/ggml-cuda/cpy.cu +490 -0
  19. data/vendor/tmp/llama.cpp/ggml-cuda/diagmask.cu +40 -0
  20. data/vendor/tmp/llama.cpp/ggml-cuda/dmmv.cu +662 -0
  21. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f16.cu +319 -0
  22. data/vendor/tmp/llama.cpp/ggml-cuda/fattn-tile-f32.cu +312 -0
  23. data/vendor/tmp/llama.cpp/ggml-cuda/fattn.cu +345 -0
  24. data/vendor/tmp/llama.cpp/ggml-cuda/getrows.cu +178 -0
  25. data/vendor/tmp/llama.cpp/ggml-cuda/im2col.cu +104 -0
  26. data/vendor/tmp/llama.cpp/ggml-cuda/mmq.cu +1564 -0
  27. data/vendor/tmp/llama.cpp/ggml-cuda/mmvq.cu +404 -0
  28. data/vendor/tmp/llama.cpp/ggml-cuda/norm.cu +221 -0
  29. data/vendor/tmp/llama.cpp/ggml-cuda/pad.cu +49 -0
  30. data/vendor/tmp/llama.cpp/ggml-cuda/pool2d.cu +94 -0
  31. data/vendor/tmp/llama.cpp/ggml-cuda/quantize.cu +45 -0
  32. data/vendor/tmp/llama.cpp/ggml-cuda/rope.cu +271 -0
  33. data/vendor/tmp/llama.cpp/ggml-cuda/scale.cu +31 -0
  34. data/vendor/tmp/llama.cpp/ggml-cuda/softmax.cu +205 -0
  35. data/vendor/tmp/llama.cpp/ggml-cuda/sumrows.cu +40 -0
  36. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  37. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  38. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  39. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  40. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  41. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  42. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  43. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  44. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  45. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  46. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  47. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  48. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  49. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  50. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  51. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  52. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  53. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  54. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  55. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  56. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  57. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  58. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  59. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  60. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  61. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  62. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  63. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  64. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  65. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  66. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  67. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  68. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  69. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  70. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  71. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  72. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  73. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  74. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  75. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  76. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  77. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  78. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  79. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  80. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  81. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  82. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  83. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  84. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  85. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  86. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  87. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  88. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  89. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  90. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  91. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  92. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  93. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  94. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  95. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  96. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  97. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  98. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  99. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  100. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  101. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  102. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  103. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  104. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  105. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  106. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  107. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  108. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  109. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  110. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  111. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  112. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  113. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  114. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  115. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  116. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  117. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  118. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  119. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  120. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  121. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  122. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +10 -0
  123. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +9 -0
  124. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +10 -0
  125. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +10 -0
  126. data/vendor/tmp/llama.cpp/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +8 -0
  127. data/vendor/tmp/llama.cpp/ggml-cuda/tsembd.cu +47 -0
  128. data/vendor/tmp/llama.cpp/ggml-cuda/unary.cu +266 -0
  129. data/vendor/tmp/llama.cpp/ggml-cuda/upscale.cu +51 -0
  130. data/vendor/tmp/llama.cpp/ggml-cuda.cu +35 -16
  131. data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
  132. data/vendor/tmp/llama.cpp/ggml-kompute.cpp +21 -7
  133. data/vendor/tmp/llama.cpp/ggml-metal.h +1 -1
  134. data/vendor/tmp/llama.cpp/ggml-metal.m +99 -35
  135. data/vendor/tmp/llama.cpp/ggml-metal.metal +146 -80
  136. data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
  137. data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
  138. data/vendor/tmp/llama.cpp/ggml-sycl.cpp +345 -227
  139. data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +99301 -39793
  140. data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +458 -329
  141. data/vendor/tmp/llama.cpp/ggml.c +301 -409
  142. data/vendor/tmp/llama.cpp/ggml.h +19 -23
  143. data/vendor/tmp/llama.cpp/llama.cpp +855 -651
  144. data/vendor/tmp/llama.cpp/llama.h +28 -48
  145. metadata +121 -6
  146. data/vendor/tmp/llama.cpp/ggml-mpi.c +0 -216
  147. data/vendor/tmp/llama.cpp/ggml-mpi.h +0 -39
  148. data/vendor/tmp/llama.cpp/ggml-opencl.cpp +0 -2305
  149. data/vendor/tmp/llama.cpp/ggml-opencl.h +0 -36
data/vendor/tmp/llama.cpp/ggml-sycl.cpp
@@ -2944,6 +2944,57 @@ namespace dpct
     using shared_memory = detail::device_memory<T, shared, Dimension>;
 
 
+    template <typename T,
+              sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+              sycl::memory_scope memoryScope = sycl::memory_scope::device>
+    inline T atomic_fetch_add(T *addr, T operand) {
+      auto atm =
+          sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+      return atm.fetch_add(operand);
+    }
+
+    template <sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+              sycl::memory_scope memoryScope = sycl::memory_scope::device,
+              typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+      auto atm =
+          sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+      return atm.fetch_add(operand);
+    }
+
+    template <typename T, sycl::access::address_space addressSpace =
+                              sycl::access::address_space::global_space>
+    inline T atomic_fetch_add(T *addr, T operand,
+                              sycl::memory_order memoryOrder) {
+      switch (memoryOrder) {
+        case sycl::memory_order::relaxed:
+          return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+                                  sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::acq_rel:
+          return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+                                  sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::seq_cst:
+          return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+                                  sycl::memory_scope::device>(addr, operand);
+        default:
+          assert(false && "Invalid memory_order for atomics. Valid memory_order for "
+                          "atomics are: sycl::memory_order::relaxed, "
+                          "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+      }
+    }
+
+    template <sycl::access::address_space addressSpace =
+                  sycl::access::address_space::global_space,
+              typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+                               sycl::memory_order memoryOrder) {
+      atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+    }
+
 } // COPY from DPCT head files
 
 #define GGML_COMMON_DECL_SYCL
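For context, a minimal usage sketch of the atomic pattern these new dpct-style helpers wrap (not part of the diff; assumes SYCL 2020, a queue q, and a USM-allocated counter):

    #include <sycl/sycl.hpp>

    // every work-item bumps a shared counter with a relaxed, device-scope atomic,
    // which is what the atomic_fetch_add overloads above reduce to
    void bump(sycl::queue &q, int *counter, int n) {
        q.parallel_for(sycl::range<1>(n), [=](sycl::id<1>) {
            sycl::atomic_ref<int, sycl::memory_order::relaxed,
                             sycl::memory_scope::device,
                             sycl::access::address_space::global_space>
                ref(*counter);
            ref.fetch_add(1);
        }).wait();
    }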
@@ -2971,20 +3022,19 @@ static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;
 
 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610          //todo for hardward optimize.
+#define VER_4VEC 130          //todo for hardward optimize.
 #define VER_GEN9 700          //todo for hardward optimize.
 #define VER_GEN12 1000000     //todo for hardward optimize.
 #define VER_GEN13 (VER_GEN12 + 1030)   //todo for hardward optimize.
 
 #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares
 
-
-//define for XMX in Intel GPU
-//TODO: currently, it's not used for XMX really.
-#define SYCL_USE_XMX
+#if !defined(GGML_SYCL_FORCE_MMQ)
+    #define SYCL_USE_XMX
+#endif
 
 // max batch size to use MMQ kernels when tensor cores are available
-#define XMX_MAX_BATCH_SIZE 32
+#define MMQ_MAX_BATCH_SIZE 32
 
 
 #if defined(_MSC_VER)
@@ -3060,6 +3110,7 @@ void ggml_sycl_get_device_description(int device, char * description, size_t d
 bool   ggml_backend_is_sycl(ggml_backend_t backend);
 int    ggml_backend_sycl_get_device(ggml_backend_t backend);
 int    get_main_device();
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
 void   print_ggml_tensor(const char*name, struct ggml_tensor *src);
 void   log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
 
@@ -8830,12 +8881,11 @@ static void rope(
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template<typename T, bool has_pos>
+template<typename T, bool has_pos, bool has_freq_facs>
 static void rope_neox(
     const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
-    ,
-    const sycl::nd_item<3> &item_ct1) {
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims,
+    const float * freq_factors, const sycl::nd_item<3> &item_ct1) {
     const int col = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
                          item_ct1.get_local_id(1));
 
@@ -8863,8 +8913,10 @@ static void rope_neox(
     float cur_rot = inv_ndims * ic - ib;
 
     const int p = has_pos ? pos[i2] : 0;
+    const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
+
     const float theta_base =
-        p * freq_scale * dpct::pow(theta_scale, col / 2.0f);
+        p * freq_scale * dpct::pow(theta_scale, col / 2.0f)/freq_factor;
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
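In effect, with frequency factors the per-channel angle computed by rope_neox becomes theta = p * freq_scale * theta_scale^(col/2) / freq_factors[ic/2], falling back to a divisor of 1.0f when no factors are given. A scalar model of that computation (illustration only, not from the diff):

    #include <cmath>

    // mirrors the theta_base expression above; freq_factors may be null,
    // which corresponds to the has_freq_facs = false template path
    float rope_theta(int p, float freq_scale, float theta_scale,
                     int col, int ic, const float *freq_factors) {
        const float freq_factor = freq_factors ? freq_factors[ic / 2] : 1.0f;
        return p * freq_scale * std::pow(theta_scale, col / 2.0f) / freq_factor;
    }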
@@ -8876,49 +8928,6 @@ static void rope_neox(
     dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static void rope_glm_f32(
-    const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-    int n_ctx
-    , const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int half_n_dims = ncols/4;
-
-    if (col >= half_n_dims) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-    const int i2 = row/p_delta_rows;
-
-    const float col_theta_scale = dpct::pow(freq_base, -2.0f * col / ncols);
-    // FIXME: this is likely wrong
-    const int p = pos != nullptr ? pos[i2] : 0;
-
-    const float theta = sycl::min(p, n_ctx - 2) * freq_scale * col_theta_scale;
-    const float sin_theta = sycl::sin((float)theta);
-    const float cos_theta = sycl::cos((float)theta);
-
-    const float x0 = x[i + 0];
-    const float x1 = x[i + half_n_dims];
-
-    dst[i + 0]           = x0*cos_theta - x1*sin_theta;
-    dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
-
-    const float block_theta =
-        ((float)sycl::max(p - n_ctx - 2, 0)) * col_theta_scale;
-    const float sin_block_theta = sycl::sin((float)block_theta);
-    const float cos_block_theta = sycl::cos((float)block_theta);
-
-    const float x2 = x[i + half_n_dims * 2];
-    const float x3 = x[i + half_n_dims * 3];
-
-    dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
-    dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -12413,7 +12422,7 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
                            const int32_t *pos, float freq_scale,
                            int p_delta_rows, float freq_base, float ext_factor,
                            float attn_factor, rope_corr_dims corr_dims,
-                           dpct::queue_ptr stream) {
+                           const float * freq_factors, dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
@@ -12423,57 +12432,51 @@ static void rope_neox_sycl(const T *x, T *dst, int ncols, int n_dims, int nrows,
     const float inv_ndims = -1.0f / n_dims;
 
     if (pos == nullptr) {
-        /*
-        DPCT1049:42: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, false>(x, dst, ncols, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, inv_ndims,
-                                    item_ct1);
-            });
+        if (freq_factors == nullptr) {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, false, false>(x, dst, ncols, n_dims, pos, freq_scale,
+                                               p_delta_rows, ext_factor, attn_factor,
+                                               corr_dims, theta_scale, inv_ndims, freq_factors,
+                                               item_ct1);
+                });
+        } else {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, false, true>(x, dst, ncols, n_dims, pos, freq_scale,
+                                              p_delta_rows, ext_factor, attn_factor,
+                                              corr_dims, theta_scale, inv_ndims, freq_factors,
+                                              item_ct1);
+                });
+        }
     } else {
-        /*
-        DPCT1049:43: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, true>(x, dst, ncols, n_dims, pos, freq_scale,
-                                   p_delta_rows, ext_factor, attn_factor,
-                                   corr_dims, theta_scale, inv_ndims, item_ct1);
-            });
+        if (freq_factors == nullptr) {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, true, false>(x, dst, ncols, n_dims, pos, freq_scale,
+                                              p_delta_rows, ext_factor, attn_factor,
+                                              corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+                });
+        } else {
+            stream->parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    rope_neox<T, true, true>(x, dst, ncols, n_dims, pos, freq_scale,
+                                             p_delta_rows, ext_factor, attn_factor,
+                                             corr_dims, theta_scale, inv_ndims, freq_factors, item_ct1);
+                });
+        }
     }
 }
 
-static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
-                              const int32_t *pos, float freq_scale,
-                              int p_delta_rows, float freq_base, int n_ctx,
-                              dpct::queue_ptr stream) {
-    GGML_ASSERT(ncols % 4 == 0);
-    const sycl::range<3> block_dims(1, 1, SYCL_ROPE_BLOCK_SIZE / 4);
-    const int num_blocks_x = (ncols + SYCL_ROPE_BLOCK_SIZE - 1) / SYCL_ROPE_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             rope_glm_f32(x, dst, ncols, pos, freq_scale,
-                                          p_delta_rows, freq_base, n_ctx,
-                                          item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
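The four parallel_for branches above exist so that has_pos and has_freq_facs stay compile-time template parameters instead of per-element runtime checks inside the kernel. A reduced sketch of the pattern (illustration only):

    // the false instantiation compiles to a kernel body with no freq_factors load at all
    template <bool has_freq_facs>
    void kernel_body(const float *freq_factors, float &out) {
        out = has_freq_facs ? freq_factors[0] : 1.0f;  // resolved at compile time
    }

    void dispatch(const float *freq_factors, float &out) {
        if (freq_factors == nullptr) {
            kernel_body<false>(freq_factors, out);
        } else {
            kernel_body<true>(freq_factors, out);
        }
    }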
@@ -13501,6 +13504,10 @@ inline void ggml_sycl_op_concat(const ggml_tensor *src0,
                                 const float *src0_dd, const float *src1_dd,
                                 float *dst_dd,
                                 const dpct::queue_ptr &main_stream) {
+#pragma message("TODO: generalize concat kernel for dim != 2")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
+    int dim = dst->op_params[0];
+    GGML_ASSERT(dim == 2);
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13986,9 +13993,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
                               ggml_tensor *dst, const float *src0_dd,
                               const float *src1_dd, float *dst_dd,
                               const dpct::queue_ptr &main_stream) {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225")
-    GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
+    const ggml_tensor * src2 = dst->src[2];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
@@ -14002,8 +14007,8 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     //const int n_past      = ((int32_t *) dst->op_params)[0];
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     const int mode          = ((int32_t *) dst->op_params)[2];
-    const int n_ctx         = ((int32_t *) dst->op_params)[3];
-    const int n_orig_ctx    = ((int32_t *) dst->op_params)[4];
+    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
 
     // RoPE alteration for extended context
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -14014,6 +14019,7 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
+    const float * freq_factors = nullptr;
     const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
@@ -14022,26 +14028,35 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     }
 
     const bool is_neox = mode & 2;
-    const bool is_glm  = mode & 4;
+
+#pragma message("TODO: update rope NORM mode to match NEOX mode")
+#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634")
+
+    if (is_neox) {
+        pos = (const int32_t *) src1_dd;
+
+        if (src2 != nullptr) {
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for !is_neox");
+    }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
-    if (is_glm) {
-        GGML_ASSERT(false);
-        rope_glm_f32_sycl(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
-    } else if (is_neox) {
+    if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
             rope_neox_sycl(
                 (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, main_stream
+                attn_factor, corr_dims, freq_factors, main_stream
             );
         } else if (src0->type == GGML_TYPE_F16) {
             rope_neox_sycl((const sycl::half *)src0_dd, (sycl::half *)dst_dd,
                            ne00, n_dims, nrows, pos, freq_scale, ne01,
                            freq_base, ext_factor, attn_factor, corr_dims,
-                           main_stream);
+                           freq_factors, main_stream);
         } else {
             GGML_ASSERT(false);
        }
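Condensed, the new wiring in ggml_sycl_op_rope reads the optional frequency factors from the third source tensor and only honors them on the NEOX path. A sketch (not from the diff, using ggml_tensor fields as declared in ggml.h):

    // returns the frequency-factor table, or nullptr when absent / not applicable
    const float * get_freq_factors(const struct ggml_tensor *dst, int mode) {
        const struct ggml_tensor *src2 = dst->src[2];  // optional freq factors
        const bool is_neox = mode & 2;
        if (!is_neox) {
            return nullptr;  // this version still asserts src2 == nullptr here
        }
        return src2 != nullptr ? (const float *) src2->data : nullptr;
    }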
@@ -15108,7 +15123,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
@@ -15173,6 +15188,29 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+}
+
+bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_F16:
+            return true;
+        default:
+            return false;
+    }
+}
 
 static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
@@ -15189,75 +15227,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
-#ifdef SYCL_USE_XMX
-    const bool use_xmx = true;
-#else
-    const bool use_xmx = false;
-#endif
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
 
-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+    bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
-    if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
+
+    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
         ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
         ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
-#ifdef GGML_SYCL_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_vec_q = use_mul_mat_vec_q ||
-                (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
-                (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
-                (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
-                (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
-
-
-#endif // GGML_SYCL_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-
-            if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
-                use_mul_mat_q = false;
-            }
-
-            if (use_mul_mat_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-            }
-        }
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
    } else {
-        GGML_ASSERT(false);
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
    }
 }
 
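The rewritten dispatch reduces to a fixed priority order over predicates computed up front. A condensed model (illustration only; the real checks also involve split buffers, permutation/contiguity, and the batch-size caps shown in the hunk):

    enum class mm_path { vec_p021, vec_nc, batched, dmmv, mmvq, mmq, generic };

    mm_path choose_mm_path(bool f16_permuted_single_col, bool f16_noncontig_single_col,
                           bool f16_f16_batched, bool use_dmmv, bool use_mmvq, bool use_mmq) {
        if (f16_permuted_single_col)  return mm_path::vec_p021;
        if (f16_noncontig_single_col) return mm_path::vec_nc;
        if (f16_f16_batched)          return mm_path::batched;
        if (use_dmmv)                 return mm_path::dmmv;
        if (use_mmvq)                 return mm_path::mmvq;
        if (use_mmq)                  return mm_path::mmq;  // currently never taken: supports_mmq returns false
        return mm_path::generic;                            // oneMKL GEMM fallback
    }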
@@ -15434,22 +15439,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
 }
 #endif
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+__dpct_inline__ static void k_copy_src1_to_contiguous(
+    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
+    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
+    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
+    const sycl::nd_item<3> &item_ct1, int &src1_row) {
+    int32_t iid1 = item_ct1.get_group(2);
+    int32_t id = item_ct1.get_group(1);
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    if (item_ct1.get_local_id(2) == 0) {
+        src1_row =
+            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
+                cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    /*
+    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
+    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+    performance if there is no access to global memory.
+    */
+    item_ct1.barrier();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+#pragma unroll
+    for (int i = item_ct1.get_local_id(2); i < ne10;
+         i += item_ct1.get_local_range(2)) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+__dpct_inline__ static void k_copy_dst_from_contiguous(
+    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
+    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
+    size_t nb2, const sycl::nd_item<3> &item_ct1) {
+    int32_t i = item_ct1.get_group(2);
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+#pragma unroll
+    for (int j = item_ct1.get_local_id(2); j < ne0;
+         j += item_ct1.get_local_range(2)) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                                  const ggml_tensor *src1,
                                  ggml_tensor *dst) try {
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
-                "mul_mat_id does not support split buffers");
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
+
     const ggml_tensor *ids = dst->src[2];
-    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1  =  dst->nb[1];
+    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
 
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
-    const char *ids_dev = (const char *)ids->data;
+    const char * ids_dev = (const char *) ids->data;
 
     SYCL_CHECK(CHECK_TRY_ERROR(
         stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
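The gather kernel above compacts routed rows with an atomic cursor: one thread per work-group claims the next free slot in the contiguous buffer and records where the row came from, so the matching scatter kernel can put results back. The core idea in plain C++, with std::atomic standing in for the device atomic (illustration only):

    #include <atomic>
    #include <cstdint>

    struct row_mapping { int32_t i1, i2; };

    // claim a destination slot for a routed row and remember its origin
    int claim_slot(std::atomic<int> &cur_row, row_mapping *mapping,
                   int32_t id, int32_t iid1) {
        const int slot = cur_row.fetch_add(1, std::memory_order_relaxed);
        mapping[slot] = {id, iid1};  // consumed later by the dst scatter kernel
        return slot;
    }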
@@ -15489,24 +15558,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
-
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id =
-                *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
-                                   id * ids->nb[0]);
-
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
 
             src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
+                src0_original + i02*nb02;
             src1_row_extra.data_device[g_main_device] =
-                src1_original + i01 * src1->nb[1];
+                src1_original + + i11*nb11 + i12*nb12;
             dst_row_extra.data_device[g_main_device] =
-                dst_original + i01 * dst->nb[1];
+                dst_original + i1*nb1 + i2*nb2;
 
             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -15515,64 +15600,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
         src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
         dst_row_extra.data_device[g_main_device]  =  dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                SYCL_CHECK(CHECK_TRY_ERROR(
-                    stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
-                                   src1_original + i01 * nb11, nb11)));
-                num_src1_rows++;
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
            }
 
-            src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
 
+            sycl_pool_alloc<int> dev_cur_src1_row(1);
+            sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
+            SYCL_CHECK(CHECK_TRY_ERROR(
+                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
+
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
+                stream->submit([&](sycl::handler &cgh) {
+                    sycl::local_accessor<int, 0> src1_row_acc(cgh);
+
+                    char *__restrict src1_contiguous_get =
+                        src1_contiguous.get();
+                    int *__restrict dev_cur_src1_row_get =
+                        dev_cur_src1_row.get();
+                    mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
+                    size_t ids_nb_ct6 = ids->nb[1];
+                    size_t ids_nb_ct7 = ids->nb[0];
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_src1_to_contiguous(
+                                src1_original, src1_contiguous_get,
+                                dev_cur_src1_row_get,
+                                dev_row_mapping_get, ids_dev, i02,
+                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
+                                item_ct1, src1_row_acc);
+                        });
+                });
+            }
+
+            src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
+
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
             src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
 
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
 
-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                    dst_original + i01 * nb1,
-                    dst_contiguous.get() + num_src1_rows * nb1, nb1)));
-                num_src1_rows++;
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+                sycl::range<3> grid_dims(1, 1, num_src1_rows);
+                stream->submit([&](sycl::handler &cgh) {
+                    const char *__restrict dst_contiguous_get =
+                        dst_contiguous.get();
+                    const mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_dst_from_contiguous(dst_original,
+                                                       dst_contiguous_get,
+                                                       dev_row_mapping_get,
+                                                       ne0, nb1, nb2, item_ct1);
+                        });
+                });
             }
         }
     }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-    }
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
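Taken together, the rewritten else branch replaces the old per-row memcpy loops with a gather, one batched matmul, scatter sequence per expert. A host-side outline under assumed helper names (the helpers below are illustrative stand-ins, not functions from the diff):

    // simplified shape of the new per-expert loop; error handling omitted
    for (int64_t i02 = 0; i02 < n_as; ++i02) {
        const int64_t rows = count_rows_routed_to(i02);   // host scan of ids
        if (rows == 0) continue;
        launch_copy_src1_to_contiguous(i02, rows);        // device gather + row mapping
        run_batched_mul_mat(i02, rows);                   // one matmul over compacted rows
        launch_copy_dst_from_contiguous(rows);            // device scatter via row mapping
    }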
@@ -16555,10 +16674,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
     UNUSED(buffer);
 }
 
-// unused at the moment
-//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
-//    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
-//}
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
+}
 
 GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;