@fugood/llama.node 0.3.13 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +89 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/CMakeLists.txt +9 -1
  25. package/src/llama.cpp/cmake/common.cmake +2 -0
  26. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  27. package/src/llama.cpp/common/arg.cpp +132 -13
  28. package/src/llama.cpp/common/chat.cpp +960 -266
  29. package/src/llama.cpp/common/chat.h +135 -0
  30. package/src/llama.cpp/common/common.cpp +33 -174
  31. package/src/llama.cpp/common/common.h +27 -67
  32. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  33. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  34. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  35. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  36. package/src/llama.cpp/common/sampling.cpp +45 -7
  37. package/src/llama.cpp/common/speculative.cpp +10 -9
  38. package/src/llama.cpp/common/speculative.h +1 -1
  39. package/src/llama.cpp/docs/build.md +45 -7
  40. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  41. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +4 -2
  42. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -1
  43. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  44. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  45. package/src/llama.cpp/examples/imatrix/imatrix.cpp +3 -4
  46. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  48. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +5 -5
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  50. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  51. package/src/llama.cpp/examples/llava/clip.h +19 -3
  52. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  53. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  54. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  55. package/src/llama.cpp/examples/lookahead/lookahead.cpp +7 -6
  56. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  57. package/src/llama.cpp/examples/main/main.cpp +79 -34
  58. package/src/llama.cpp/examples/parallel/parallel.cpp +6 -5
  59. package/src/llama.cpp/examples/passkey/passkey.cpp +15 -14
  60. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  61. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  62. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  63. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +196 -108
  67. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  68. package/src/llama.cpp/examples/server/server.cpp +113 -101
  69. package/src/llama.cpp/examples/server/utils.hpp +94 -105
  70. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  74. package/src/llama.cpp/examples/tts/tts.cpp +263 -151
  75. package/src/llama.cpp/ggml/CMakeLists.txt +14 -1
  76. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  77. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  79. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  80. package/src/llama.cpp/ggml/include/ggml.h +29 -1
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -34
  82. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  83. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  84. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  85. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  87. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -7
  88. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  89. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +139 -16
  90. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  91. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1546 -387
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1645 -113
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  102. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  103. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  104. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  105. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  106. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  107. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +242 -0
  108. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -6
  109. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  110. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +315 -138
  111. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  112. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  113. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +5 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +117 -36
  117. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  118. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +147 -16
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  124. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +307 -0
  125. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +262 -746
  127. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  128. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -78
  129. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  130. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +4 -1
  132. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  134. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  135. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +498 -188
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -3
  139. package/src/llama.cpp/ggml/src/ggml.c +93 -5
  140. package/src/llama.cpp/include/llama.h +105 -27
  141. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  142. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  143. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  144. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  145. package/src/llama.cpp/requirements.txt +1 -0
  146. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  147. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  148. package/src/llama.cpp/src/llama-adapter.h +11 -9
  149. package/src/llama.cpp/src/llama-arch.cpp +123 -16
  150. package/src/llama.cpp/src/llama-arch.h +19 -0
  151. package/src/llama.cpp/src/llama-batch.h +2 -2
  152. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  153. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  154. package/src/llama.cpp/src/llama-context.h +214 -77
  155. package/src/llama.cpp/src/llama-cparams.h +1 -0
  156. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  157. package/src/llama.cpp/src/llama-grammar.h +12 -3
  158. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  159. package/src/llama.cpp/src/llama-graph.h +574 -0
  160. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  161. package/src/llama.cpp/src/llama-hparams.h +9 -0
  162. package/src/llama.cpp/src/llama-io.cpp +15 -0
  163. package/src/llama.cpp/src/llama-io.h +35 -0
  164. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  165. package/src/llama.cpp/src/llama-kv-cache.h +178 -109
  166. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  167. package/src/llama.cpp/src/llama-memory.h +21 -0
  168. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  169. package/src/llama.cpp/src/llama-model.cpp +8230 -122
  170. package/src/llama.cpp/src/llama-model.h +34 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  172. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  173. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  174. package/src/llama.cpp/src/llama.cpp +51 -9837
  175. package/src/llama.cpp/tests/test-backend-ops.cpp +247 -112
  176. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  177. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  178. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  179. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  180. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  181. package/src/llama.cpp/common/chat.hpp +0 -55
  182. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  183. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
  184. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -39,8 +39,14 @@
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/sycl_hw.hpp"
+#include "ggml-sycl/getrows.hpp"
+#include "ggml.h"
 
 static bool g_sycl_loaded = false;
+int g_ggml_sycl_debug = 0;
+int g_ggml_sycl_disable_optimize = 0;
+int g_ggml_sycl_disable_graph = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -63,14 +69,18 @@ static ggml_sycl_device_info ggml_sycl_init() {
     for (int i = 0; i < info.device_count; ++i) {
         info.devices[i].vmm = 0;
         dpct::device_info prop;
+        sycl::device device = dpct::dev_mgr::instance().get_device(i);
+
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(i))));
+            prop, device)));
 
         info.default_tensor_split[i] = total_vram;
         total_vram += prop.get_global_mem_size();
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
+        info.devices[i].hw_info = get_device_hw_info(&device);
+        info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch);
 
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
@@ -86,7 +96,7 @@ const ggml_sycl_device_info & ggml_sycl_info() {
     return info;
 }
 
-void print_device_detail(int id, sycl::device &device, std::string device_type) {
+static void print_device_detail(int id, sycl::device &device, std::string device_type) {
 
     dpct::device_info prop;
     SYCL_CHECK(CHECK_TRY_ERROR(
@@ -109,6 +119,27 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
+static void print_device_opt_feature(int device_count) {
+    GGML_LOG_INFO("SYCL Optimization Feature:\n");
+    GGML_LOG_INFO(
+        "|ID|        Device Type|Reorder|\n");
+    GGML_LOG_INFO(
+        "|--|-------------------|-------|\n");
+    std::map<std::string, size_t> DeviceNums;
+    for (int id = 0; id < device_count; ++id) {
+        sycl::device device = dpct::dev_mgr::instance().get_device(id);
+        std::string backend_type = get_device_backend_and_type(device);
+        int type_id = DeviceNums[backend_type]++;
+        std::stringstream device_type;
+        device_type << "[" << backend_type << ":" << std::to_string(type_id)
+                    << "]";
+        std::string device_type_s = device_type.str();
+        device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
+        GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
+                      ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
+    }
+
+}
 void ggml_backend_sycl_print_sycl_devices() {
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
     int device_count = dpct::dev_mgr::instance().device_count();
@@ -137,6 +168,8 @@ void ggml_backend_sycl_print_sycl_devices() {
                     << "]";
         print_device_detail(id, device, device_type.str());
     }
+
+    print_device_opt_feature(device_count);
 }
 
 static inline int get_sycl_env(const char *env_name, int default_val) {
@@ -157,18 +190,24 @@ static void ggml_check_sycl() try {
     static bool initialized = false;
 
     if (!initialized) {
-        GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+        g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+        g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+        GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+        GGML_LOG_INFO("Running with Environment Variables:\n");
+        GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+        GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+        GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
-        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n");
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
 #else
-        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n");
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
 #endif
 #if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("GGML_SYCL_F16: yes\n");
+        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
 #else
-        GGML_LOG_INFO("GGML_SYCL_F16: no\n");
+        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
 #endif
 
     /* NOT REMOVE, keep it for next optimize for XMX.
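The hunk above wires up two new runtime toggles, GGML_SYCL_DISABLE_OPT and GGML_SYCL_DISABLE_GRAPH, alongside the existing GGML_SYCL_DEBUG; note that GGML_SYCL_DISABLE_GRAPH defaults to 1, so SYCL graph execution stays off unless explicitly enabled. Only the declaration of the helper that reads them, get_sycl_env, is visible in this diff; the body below is a minimal sketch of how such a getenv-based helper is typically written, not the package's actual code:

    #include <cstdlib>  // std::getenv, std::atoi

    // Sketch (assumption): read an integer toggle from the environment,
    // falling back to default_val when the variable is unset.
    static inline int get_sycl_env(const char * env_name, int default_val) {
        const char * val = std::getenv(env_name);
        return val != nullptr ? std::atoi(val) : default_val;
    }

For example, running with GGML_SYCL_DEBUG=1 would make g_ggml_sycl_debug non-zero and enable the GGML_SYCL_DEBUG logging seen throughout this file.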
@@ -240,19 +279,27 @@ struct ggml_backend_sycl_buffer_context {
     void * dev_ptr = nullptr;
     queue_ptr stream;
     std::string name;
+    optimize_feature opt_feature;
+    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
 
-     ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
+    ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
         device(device), dev_ptr(dev_ptr), stream(stream) {
         check_allow_gpu_index(device);
         name = (GGML_SYCL_NAME + std::to_string(device));
+        opt_feature = ggml_sycl_info().devices[device].opt_feature;
     }
 
-
     ~ggml_backend_sycl_buffer_context() {
         if (dev_ptr != nullptr) {
             ggml_sycl_set_device(device);
             SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
         }
+
+        //release extra used by tensors
+        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+            release_extra_gpu(extra);
+        }
+
     }
 };
 
@@ -280,16 +327,20 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->dev_ptr;
 }
 
-static void
+static enum ggml_status
 ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                      ggml_tensor *tensor) try {
     ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
 
     if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        return;
+        return GGML_STATUS_SUCCESS;
+    }
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+        tensor->extra = extra;
+        ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
     }
-
 
     if (ggml_is_quantized(tensor->type)) {
         // initialize padding to 0 to avoid possible NaN values
@@ -302,6 +353,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                       padded_size - original_size).wait()));
         }
     }
+    return GGML_STATUS_SUCCESS;
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
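The hunks above adapt the buffer's init_tensor callback to the updated ggml-backend interface, which now returns enum ggml_status instead of void, and allocate a ggml_tensor_extra_gpu record for Q4_0 tensors so the new reorder optimization has per-tensor state to hang off. A minimal sketch of the new callback shape (example_init_tensor is a hypothetical name, and the body is simplified by us; only the signature change is taken from this diff):

    #include "ggml-backend-impl.h"  // ggml_backend_buffer_t, ggml_status (assumed include)

    // Sketch (assumption): the init_tensor hook can now report failure
    // to the caller instead of returning void.
    static enum ggml_status example_init_tensor(ggml_backend_buffer_t buffer,
                                                ggml_tensor * tensor) {
        if (tensor->view_src != nullptr) {
            return GGML_STATUS_SUCCESS;  // views share the parent's storage
        }
        // ... backend-specific per-tensor setup would go here ...
        return GGML_STATUS_SUCCESS;
    }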
@@ -315,7 +367,6 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                 size_t size) try {
 
     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
     ggml_sycl_set_device(ctx->device);
     auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
     SYCL_CHECK(
@@ -353,7 +404,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
                     const void *ptr_src, size_t size) {
     char *host_buf = (char *)malloc(size);
     q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
@@ -439,6 +490,22 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
+static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
+    GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
+    if (buffer == nullptr) {
+        return;
+    }
+
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
+
+    if (ctx != nullptr) {
+        for (ggml_tensor_extra_gpu * extra : ctx->tensor_extras) {
+            release_extra_gpu(extra);
+        }
+        ctx->tensor_extras.clear(); // reset the tensor_extras vector
+    }
+}
+
 static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
@@ -448,7 +515,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
     /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
     /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
     /* .clear           = */ ggml_backend_sycl_buffer_clear,
-    /* .reset           = */ NULL,
+    /* .reset           = */ ggml_backend_sycl_buffer_reset,
 };
 
 // sycl buffer type
@@ -529,7 +596,6 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
 
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     auto dev_count = ggml_backend_sycl_get_device_count();
 
@@ -557,7 +623,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     return &ggml_backend_sycl_buffer_types[device];
 }
 
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
+static ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
     int device = ctx->device;
@@ -659,32 +725,7 @@ struct ggml_backend_sycl_split_buffer_type_context {
 struct ggml_backend_sycl_split_buffer_context {
     ~ggml_backend_sycl_split_buffer_context() try {
         for (ggml_tensor_extra_gpu * extra : tensor_extras) {
-            for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-                for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-                    if (extra->events[i][is] != nullptr) {
-                        /*
-                        DPCT1009:206: SYCL uses exceptions to report errors and
-                        does not use the error codes. The original code was
-                        commented out and a warning string was inserted. You
-                        need to rewrite this code.
-                        */
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            dpct::destroy_event(extra->events[i][is])));
-                    }
-                }
-                if (extra->data_device[i] != nullptr) {
-                    /*
-                    DPCT1009:207: SYCL uses exceptions to report errors and does
-                    not use the error codes. The original code was commented out
-                    and a warning string was inserted. You need to rewrite this
-                    code.
-                    */
-                    ggml_sycl_set_device(i);
-                    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
-                        extra->data_device[i], *(streams[i]))));
-                }
-            }
-            delete extra;
+            release_extra_gpu(extra, streams);
         }
     }
     catch (sycl::exception const &exc) {
@@ -709,7 +750,7 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff
     GGML_UNUSED(buffer);
 }
 
-static void
+static enum ggml_status
 ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                            ggml_tensor *tensor) try {
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
@@ -722,7 +763,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
     ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
 
     ctx->tensor_extras.push_back(extra);
-        ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
+    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
 
     for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
         int64_t row_low, row_high;
@@ -784,6 +825,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
         }
     }
     tensor->extra = extra;
+    return GGML_STATUS_SUCCESS;
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -1263,8 +1305,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
 // struct ggml_sycl_pool_vmm : public ggml_sycl_pool
 
 /// kernels
-
-typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 typedef void (*ggml_sycl_op_mul_mat_t)(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1336,83 +1376,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
     reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
 
-template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void k_get_rows(
-        const void * src0, const int32_t * src1, dst_t * dst,
-        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-        size_t s10, size_t s11, size_t s12,
-        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                     item_ct1.get_local_id(2)) *
-                    2;
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
-
-    // dequantize
-    dfloat2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
-
-    dst_row[iybs + iqs + 0] = v.x();
-    dst_row[iybs + iqs + y_offset] = v.y();
-}
-
-template<typename src0_t, typename dst_t>
-static void k_get_rows_float(
-        const src0_t * src0, const int32_t * src1, dst_t * dst,
-        int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-        /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-        /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-        /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-        size_t s10, size_t s11, size_t s12,
-        const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                    item_ct1.get_local_id(2);
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
-    dst_row[i00] = src0_row[i00];
-}
-
 static void mul_mat_p021_f16_f32(
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1523,193 +1486,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 }
 
-static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = sycl::vec<float, 1>(*xi)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-}
-
-static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    float * dsti = (float *) cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
-    const int16_t *xi = (const int16_t *)cxi;
-    int16_t *dsti = (int16_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
-    const int32_t *xi = (const int32_t *)cxi;
-    int32_t *dsti = (int32_t *)cdsti;
-
-    *dsti = *xi;
-}
-
-template <cpy_kernel_t cpy_1>
-static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= ne) {
-        return;
-    }
-
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
-    // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
-
-    cpy_1(cx + x_offset, cdst + dst_offset);
-}
-
-static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
-    float amax = 0.0f; // absolute max
-
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = xi[j];
-        amax = sycl::fmax(amax, sycl::fabs((float)v));
-    }
-
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = xi[j]*id;
-
-        dsti->qs[j] = sycl::round((float)x0);
-    }
-}
-
-static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
-    float amax = 0.0f;
-    float vmax = 0.0f;
-
-    for (int j = 0; j < QK4_0; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float)v)) {
-            amax = sycl::fabs((float)v);
-            vmax = v;
-        }
-    }
-
-    const float d = vmax / -8;
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->d = d;
-
-    for (int j = 0; j < QK4_0/2; ++j) {
-        const float x0 = xi[0 + j]*id;
-        const float x1 = xi[QK4_0/2 + j]*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
-
-        dsti->qs[j] = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
-    float vmin = FLT_MAX;
-    float vmax = -FLT_MAX;
-
-    for (int j = 0; j < QK4_1; ++j) {
-        const float v = xi[j];
-
-        if (v < vmin) vmin = v;
-        if (v > vmax) vmax = v;
-    }
-
-    const float d = (vmax - vmin) / ((1 << 4) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
-
-    dsti->dm.x() = d;
-    dsti->dm.y() = vmin;
-
-    for (int j = 0; j < QK4_1/2; ++j) {
-        const float x0 = (xi[0 + j] - vmin)*id;
-        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
-
-        dsti->qs[j] = xi0;
-        dsti->qs[j] |= xi1 << 4;
-    }
-}
-
-template <cpy_kernel_t cpy_blck, int qk>
-static void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                   item_ct1.get_local_id(2)) *
-                  qk;
-
-    if (i >= ne) {
-        return;
-    }
-
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
-    cpy_blck(cx + x_offset, cdst + dst_offset);
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -1895,81 +1671,6 @@ static void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }
 
-template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst, const void *src0_dd,
-                          const int32_t *src1_dd, float *dst_dd,
-                          queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    GGML_ASSERT(ne00 % 2 == 0);
-
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             k_get_rows<qk, qr, dq>(
-                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-                         });
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
-
-template <typename src0_t>
-static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const src0_t *src0_dd, const int32_t *src1_dd,
-                                float *dst_dd, queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
-
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
-
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-            });
-    }
-
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
-
 static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
                                    const int ky, const int kx_padded,
                                    queue_ptr stream) {
@@ -1984,7 +1685,7 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
 
     stream->parallel_for(
         sycl::nd_range<3>(num_blocks * block_size, block_size),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
            quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
         });
 }
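This hunk and the next two swap the kernel attribute [[intel::reqd_sub_group_size(WARP_SIZE)]] for [[sycl::reqd_sub_group_size(WARP_SIZE)]]. The intel:: spelling is the older DPC++ extension; SYCL 2020 standardizes the attribute in the sycl:: namespace. A standalone sketch of the standardized form (an illustrative kernel, not taken from this diff):

    #include <sycl/sycl.hpp>

    // Pin each sub-group of this kernel to 32 work-items using the
    // SYCL 2020 attribute spelling.
    void launch(sycl::queue & q, size_t global, size_t local) {
        q.parallel_for(sycl::nd_range<1>{global, local},
                       [=](sycl::nd_item<1>) [[sycl::reqd_sub_group_size(32)]] {
                           // kernel body elided
                       });
    }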
@@ -2005,7 +1706,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
                                      nchannels_y, item_ct1);
             });
@@ -2025,7 +1726,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
 
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
                                        row_stride_x, channel_stride_x,
                                        nchannels_y / nchannels_x, item_ct1);
@@ -2033,231 +1734,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     }
 }
 
-static void
-ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
-                      const int ne01, const int ne02, const int nb00,
-                      const int nb01, const int nb02, const int nb03,
-                      const int ne10, const int ne11, const int ne12,
-                      const int nb10, const int nb11, const int nb12,
-                      const int nb13, queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
-                                           nb01, nb02, nb03, ne10, ne11, ne12,
-                                           nb10, nb11, nb12, nb13, item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK8_0 == 0);
-    const int num_blocks = ne / QK8_0;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
-}
-
-static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK4_0 == 0);
-    const int num_blocks = ne / QK4_0;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
-}
-
-static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK4_1 == 0);
-    const int num_blocks = ne / QK4_1;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
-}
-
-static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
-
-static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
 
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
-}
 
 static void scale_f32_sycl(const float *x, float *dst, const float scale,
                            const int k, queue_ptr stream) {
@@ -2290,7 +1767,7 @@ static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                          [=](sycl::nd_item<3> item_ct1)
-                             [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                             [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                              k_sum_rows_f32(x, dst, ncols, item_ct1);
                          });
 }
@@ -2493,52 +1970,6 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_d, const float *src1_d,
-                                  float *dst_d, const queue_ptr &stream) {
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
-                                src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_F32:
-            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_0:
-            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        default:
-            // TODO: k-quants
-            GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-
 static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
                                 const float *src0_d, const float *src1_d,
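The get_rows machinery deleted in this hunk, and the copy kernels removed earlier, do not disappear from the package: per the file list above (entries 119-120 and 124-125), this release adds the new ggml-sycl/cpy.cpp and ggml-sycl/getrows.cpp translation units, and this file now includes ggml-sycl/getrows.hpp at the top, so the implementations appear to have moved there rather than been dropped.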
@@ -2588,11 +2019,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
2588
2019
  if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
2589
2020
  use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
2590
2021
  dst->op_params[0] == GGML_PREC_DEFAULT) {
2591
-
2592
2022
  // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
2593
2023
  ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
2594
2024
  if (src0->type != GGML_TYPE_F16) {
2595
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
2025
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
2596
2026
  GGML_ASSERT(to_fp16_sycl != nullptr);
2597
2027
  size_t ne = row_diff*ne00;
2598
2028
  src0_as_f16.alloc(ne);
@@ -2604,7 +2034,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
2604
2034
 
2605
2035
  ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
2606
2036
  if (src1->type != GGML_TYPE_F16) {
2607
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
2037
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
2608
2038
  GGML_ASSERT(to_fp16_sycl != nullptr);
2609
2039
  size_t ne = src1_ncols*ne10;
2610
2040
  src1_as_f16.alloc(ne);
@@ -2625,13 +2055,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
2625
2055
  src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
2626
2056
  dst_f16.get(), dpct::library_data_t::real_half, ldc,
2627
2057
  dpct::library_data_t::real_half)));
2628
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2058
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2629
2059
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2630
2060
  #else
2631
2061
  auto dnnl_stream = ctx.stream_dnnl(stream);
2632
2062
  DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2633
2063
  src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
2634
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2064
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2635
2065
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2636
2066
  #endif
2637
2067
  }
@@ -2640,13 +2070,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
2640
2070
  ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
2641
2071
  ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
2642
2072
  if (src0->type != GGML_TYPE_F32) {
2643
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
2073
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
2644
2074
  GGML_ASSERT(to_fp32_sycl != nullptr);
2645
2075
  src0_ddq_as_f32.alloc(row_diff*ne00);
2646
2076
  to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
2647
2077
  }
2648
2078
  if (src1->type != GGML_TYPE_F32) {
2649
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
2079
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
2650
2080
  GGML_ASSERT(to_fp32_sycl != nullptr);
2651
2081
  src1_ddq_as_f32.alloc(src1_ncols*ne10);
2652
2082
  to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
@@ -3084,7 +2514,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
3084
2514
  for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
3085
2515
  const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
3086
2516
  const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
3087
-
3088
2517
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
3089
2518
  if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
3090
2519
  continue;
@@ -3270,6 +2699,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
3270
2699
  GGML_SYCL_DEBUG("call %s done\n", __func__);
3271
2700
  }
3272
2701
 
2702
+ static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2703
+ GGML_SYCL_DEBUG("call %s\n", __func__);
2704
+ ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
2705
+ GGML_SYCL_DEBUG("call %s done\n", __func__);
2706
+ }
2707
+
3273
2708
  static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3274
2709
  GGML_SYCL_DEBUG("call %s\n", __func__);
3275
2710
  ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
@@ -3392,7 +2827,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3392
2827
  // convert src1 to fp16
3393
2828
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
3394
2829
  if (src1->type != GGML_TYPE_F16) {
3395
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
2830
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
3396
2831
  const int64_t ne_src1 = ggml_nelements(src1);
3397
2832
  src1_f16_alloc.alloc(ne_src1);
3398
2833
  GGML_ASSERT(to_fp16_sycl != nullptr);
@@ -3488,7 +2923,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
  return false;
  }
  
- bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
  switch (type) {
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
@@ -3508,6 +2943,7 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
  }
  
  static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+
  const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
  int64_t min_compute_capability = INT_MAX;
  
@@ -3569,6 +3005,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
  } else if (use_dequantize_mul_mat_vec) {
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+ // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
  } else if (use_mul_mat_vec_q) {
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
  } else if (use_mul_mat_q) {
@@ -3701,8 +3138,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
  const int64_t i2 = i12;
  
  src0_row.data = src0_original + i02*nb02;
- src1_row.data = src1_original + + i11*nb11 + i12*nb12;
- dst_row.data = dst_original + i1*nb1 + i2*nb2;
+ src1_row.data = src1_original + i11*nb11 + i12*nb12;
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
  
  ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
  }
@@ -3821,58 +3258,6 @@ static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
  }
  
- static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst) try {
- const int64_t ne = ggml_nelements(src0);
- GGML_ASSERT(ne == ggml_nelements(src1));
-
- GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
- GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
- GGML_TENSOR_BINARY_OP_LOCALS01;
-
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- queue_ptr main_stream = ctx.stream();
-
- char * src0_ddc = (char *) src0->data;
- char * src1_ddc = (char *) src1->data;
-
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
- ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
- ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
- ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
- ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
- ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else {
- GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__,
- ggml_type_name(src0->type), ggml_type_name(src1->type));
- GGML_ABORT("fatal error");
- }
- GGML_UNUSED(dst);
- }
- catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
- }
-
- static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- // TODO: why do we pass dst as src1 here?
- ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr);
- }
-
  static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
  }
@@ -3911,7 +3296,7 @@ static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  }
  
  
- void ggml_sycl_set_main_device(const int main_device) try {
+ static void ggml_sycl_set_main_device(const int main_device) try {
  if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
  return;
  }
@@ -3932,7 +3317,7 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
  
- bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
+ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
  if (!g_sycl_loaded) return false;
  
  if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
@@ -4034,6 +3419,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
  case GGML_OP_RMS_NORM:
  ggml_sycl_rms_norm(ctx, dst);
  break;
+ case GGML_OP_L2_NORM:
+ ggml_sycl_l2_norm(ctx, dst);
+ break;
  case GGML_OP_MUL_MAT:
  if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
  return false;
@@ -4069,7 +3457,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
  ggml_sycl_clamp(ctx, dst);
  break;
  case GGML_OP_CPY:
- ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst);
+ ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
  break;
  case GGML_OP_CONT:
  ggml_sycl_dup(ctx, dst);
@@ -4111,6 +3499,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
  case GGML_OP_RWKV_WKV6:
  ggml_sycl_op_rwkv_wkv6(ctx, dst);
  break;
+ case GGML_OP_RWKV_WKV7:
+ ggml_sycl_op_rwkv_wkv7(ctx, dst);
+ break;
  case GGML_OP_GATED_LINEAR_ATTN:
  ggml_sycl_op_gated_linear_attn(ctx, dst);
  break;
@@ -4250,10 +3641,71 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }
  
- static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
- ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
- ggml_sycl_set_main_device(sycl_ctx->device);
+ static void reorder_qw(char *data_device, const int ncols, const int nrows,
+ size_t size, size_t offset, dpct::queue_ptr stream) {
+ auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+ SYCL_CHECK(
+ CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
+ .wait()));
+ GGML_ASSERT((size % sizeof(block_q4_0) == 0));
+ GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
+ int offset_blks = offset / sizeof(block_q4_0);
+ auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+ auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
+
+ stream->parallel_for(
+ size / sizeof(block_q4_0),
+ [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ const block_q4_0* x = (const block_q4_0*)tmp_buf;
+ const int ib = i;
+
+ for (int j = 0; j < QK4_0/2; j ++)
+ {
+ *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
+ }
+ *(d_ptr + ib) = x[ib].d;
+ });
+
+ sycl::free(tmp_buf, *stream);
+ }
+
+ static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
+ char* data_device = (char*)src0->data;
+ size_t ncols = src0->ne[0];
+ size_t nrows = src0->ne[1];
+ size_t size = ggml_nbytes(src0);
+
+ reorder_qw(data_device, ncols, nrows, size, 0, stream);
+ }
  
+ static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
+ ggml_tensor *src0 = dst->src[0];
+ ggml_tensor *src1 = dst->src[1];
+
+ if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
+ src1->ne[2]==1 && src1->ne[3]==1) {
+ reorder_qw(src0, stream);
+ ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
+ GGML_ASSERT(extra);
+ extra->optimized_feature.reorder = true; // used to decode/dequantize in later steps
+ }
+ }
+
+ static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
+ dpct::queue_ptr stream = ctx->stream();
+ if (ctx->optimized_graph) {
+ return;
+ }
+ ctx->optimized_graph = true;
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
+ }
+ }
+
+ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) {
+ ggml_sycl_set_main_device(sycl_ctx->device);
+ if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
  
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_tensor * node = cgraph->nodes[i];
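The `reorder_qw` kernel above rewrites a Q4_0 weight buffer in place, from ggml's array-of-blocks layout (`{fp16 d; uint8_t qs[QK4_0/2];}` per 32-element block) into a split layout: all quant nibbles first, then all scales. A host-side sketch of the same transform (hypothetical helper, using a `uint16_t` stand-in for the fp16 scale bits):

```cpp
// Host-side equivalent of the device-side reorder (sketch). Input: n blocks in
// AoS form; output buffer: n*QK4_0/2 bytes of quants followed by n fp16 scales.
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int QK4_0 = 32;
struct block_q4_0_host { uint16_t d; uint8_t qs[QK4_0 / 2]; }; // d = fp16 bit pattern

static std::vector<uint8_t> reorder_q4_0_host(const std::vector<block_q4_0_host> & blocks) {
    const size_t nb = blocks.size();
    std::vector<uint8_t> out(nb * (sizeof(uint16_t) + QK4_0 / 2));
    uint8_t * qs_ptr = out.data();              // quant region: nb * QK4_0/2 bytes
    uint8_t * d_ptr  = qs_ptr + nb * QK4_0 / 2; // scale region follows contiguously
    for (size_t ib = 0; ib < nb; ++ib) {
        std::memcpy(qs_ptr + ib * (QK4_0 / 2), blocks[ib].qs, QK4_0 / 2);
        std::memcpy(d_ptr + ib * sizeof(uint16_t), &blocks[ib].d, sizeof(uint16_t));
    }
    return out;
}
```

Splitting quants from scales turns strided 18-byte block accesses into two dense streams, which is presumably the point of the optimization.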
@@ -4274,7 +3726,46 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
  }
  GGML_ASSERT(ok);
  }
+ }
  
+ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ auto * sycl_ctx = static_cast<ggml_backend_sycl_context *>(backend->context);
+
+ #ifdef GGML_SYCL_GRAPH
+ if (!g_ggml_sycl_disable_graph) {
+ if (!sycl_ctx->exec_graph && !dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_graph)) {
+ GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device);
+ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+ return GGML_STATUS_SUCCESS;
+ }
+
+ sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+ model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
+ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+ model_sycl_graph.end_recording();
+
+ if (!sycl_ctx->exec_graph) {
+ auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
+ sycl_ctx->exec_graph = std::make_unique<
+ sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+ } else {
+ try {
+ sycl_ctx->exec_graph->update(model_sycl_graph);
+ GGML_SYCL_DEBUG("[SYCL-GRAPH] update success\n");
+ } catch (sycl::exception const & e) {
+ GGML_SYCL_DEBUG("[SYCL-GRAPH] Exception when updating graph, %s\n", e.what());
+ auto exec_graph = model_sycl_graph.finalize({sycl_ex::property::graph::updatable{}});
+ sycl_ctx->exec_graph = std::make_unique<
+ sycl_ex::command_graph<sycl_ex::graph_state::executable>>(exec_graph);
+ }
+ }
+
+ sycl_ctx->stream()->ext_oneapi_graph(*(sycl_ctx->exec_graph));
+ } else
+ #endif
+ {
+ ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
+ }
  return GGML_STATUS_SUCCESS;
  }
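The new compute path records the whole node loop into a `sycl_ex::command_graph`, finalizes it once, and on later calls replays it (or tries an in-place `update()`), cutting per-kernel submission overhead. A stripped-down record/replay sketch against the same `sycl_ext_oneapi_graph` extension, assuming a device that reports `aspect::ext_oneapi_graph`:

```cpp
// Minimal record-and-replay with oneAPI experimental graphs (sketch).
#include <sycl/sycl.hpp>
#include <sycl/ext/oneapi/experimental/graph.hpp>

namespace sycl_ex = sycl::ext::oneapi::experimental;

int main() {
    sycl::queue q{sycl::gpu_selector_v};
    float * data = sycl::malloc_device<float>(1024, q);

    sycl_ex::command_graph graph(q.get_context(), q.get_device());
    graph.begin_recording(q);                // submissions are captured, not executed yet
    q.parallel_for(1024, [=](sycl::id<1> i) { data[i] = 2.0f * float(i[0]); });
    graph.end_recording();

    auto exec = graph.finalize();       // bake the captured DAG into an executable graph
    q.ext_oneapi_graph(exec).wait();    // replay it; repeat without re-recording
    sycl::free(data, q);
    return 0;
}
```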
  
@@ -4339,7 +3830,6 @@ bool ggml_backend_is_sycl(ggml_backend_t backend) {
  }
  
  int ggml_backend_sycl_get_device_count() {
- GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
  return ggml_sycl_info().device_count;
  }
  
@@ -4429,7 +3919,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  return true;
  }
  return false;
- } break;
+ }
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(op)) {
  case GGML_UNARY_OP_NEG:
@@ -4443,11 +3933,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_TANH:
  case GGML_UNARY_OP_EXP:
- return ggml_is_contiguous(op->src[0]);
+ return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
  default:
  return false;
  }
- break;
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  {
@@ -4478,7 +3967,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  return false;
  }
  return true;
- } break;
+ }
  case GGML_OP_OUT_PROD:
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
  case GGML_OP_GET_ROWS:
@@ -4495,7 +3984,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  default:
  return false;
  }
- } break;
+ }
  case GGML_OP_CPY:
  {
  ggml_type src0_type = op->src[0]->type;
@@ -4521,13 +4010,37 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
  return true;
  }
+ if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+ return true;
+ }
  return false;
- } break;
+ }
  case GGML_OP_CONCAT:
  {
  ggml_type src0_type = op->src[0]->type;
  return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
- } break;
+ }
  case GGML_OP_DUP:
  case GGML_OP_ARGMAX:
  case GGML_OP_NONE:
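The widened `GGML_OP_CPY` matrix above adds quantized↔F32 pairs, i.e. the backend can now quantize or dequantize as part of a copy. As a reference for the F32→Q8_0 direction, this is ggml's standard per-block quantization (scalar sketch; the fp16 scale store is left as a comment since it needs an fp32→fp16 helper, assumed here):

```cpp
// Scalar reference for quantizing one 32-float block to Q8_0 (sketch).
#include <cmath>
#include <cstdint>

constexpr int QK8_0 = 32;
struct block_q8_0_host { uint16_t d; int8_t qs[QK8_0]; }; // d = fp16-encoded scale

static void quantize_block_q8_0_ref(const float * x, block_q8_0_host * y) {
    float amax = 0.0f;                                  // absolute max of the block
    for (int j = 0; j < QK8_0; ++j) amax = std::fmax(amax, std::fabs(x[j]));
    const float d  = amax / 127.0f;                     // scale so the max maps to +/-127
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    for (int j = 0; j < QK8_0; ++j) y->qs[j] = (int8_t) std::roundf(x[j] * id);
    // y->d = fp32_to_fp16(d);  // assumed helper: store the scale as fp16 bits
}
```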
@@ -4536,23 +4049,25 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
  case GGML_OP_TRANSPOSE:
+ return true;
  case GGML_OP_ADD:
  case GGML_OP_ADD1:
- case GGML_OP_LOG:
  case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
- return true;
- case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_GROUP_NORM:
- return ggml_is_contiguous(op->src[0]);
- case GGML_OP_SCALE:
  case GGML_OP_SQR:
  case GGML_OP_SQRT:
  case GGML_OP_SIN:
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
+ case GGML_OP_LOG:
+ return (op->src[0]->type == GGML_TYPE_F32);
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
+ case GGML_OP_GROUP_NORM:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_SCALE:
  return true;
  case GGML_OP_CONT:
  return op->src[0]->type != GGML_TYPE_BF16;
@@ -4583,6 +4098,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
  case GGML_OP_GATED_LINEAR_ATTN:
  return true;
  default: