@fugood/llama.node 0.3.12 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +2 -1
  18. package/package.json +1 -1
  19. package/src/LlamaCompletionWorker.cpp +14 -0
  20. package/src/LlamaContext.cpp +110 -79
  21. package/src/LlamaContext.h +1 -1
  22. package/src/common.hpp +1 -2
  23. package/src/llama.cpp/.github/workflows/build.yml +95 -13
  24. package/src/llama.cpp/.github/workflows/docker.yml +2 -0
  25. package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  27. package/src/llama.cpp/common/CMakeLists.txt +23 -6
  28. package/src/llama.cpp/common/arg.cpp +292 -14
  29. package/src/llama.cpp/common/chat.cpp +1128 -315
  30. package/src/llama.cpp/common/chat.h +135 -0
  31. package/src/llama.cpp/common/common.cpp +27 -171
  32. package/src/llama.cpp/common/common.h +41 -73
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  34. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  35. package/src/llama.cpp/common/llguidance.cpp +3 -3
  36. package/src/llama.cpp/common/log.cpp +1 -0
  37. package/src/llama.cpp/common/log.h +2 -1
  38. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
  39. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
  40. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  41. package/src/llama.cpp/common/sampling.cpp +93 -49
  42. package/src/llama.cpp/common/speculative.cpp +6 -5
  43. package/src/llama.cpp/common/speculative.h +1 -1
  44. package/src/llama.cpp/docs/build.md +47 -9
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  47. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  48. package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
  49. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
  50. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  52. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  53. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  54. package/src/llama.cpp/examples/llava/clip.h +19 -3
  55. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  56. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  57. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  58. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  59. package/src/llama.cpp/examples/main/main.cpp +73 -28
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +115 -79
  67. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/server/httplib.h +381 -292
  69. package/src/llama.cpp/examples/server/server.cpp +134 -128
  70. package/src/llama.cpp/examples/server/utils.hpp +95 -106
  71. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  72. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  73. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  74. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  75. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  76. package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
  77. package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
  79. package/src/llama.cpp/ggml/include/ggml.h +6 -2
  80. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  81. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  82. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  83. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  84. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  85. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  86. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  87. package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
  88. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  89. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  90. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
  96. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
  102. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  103. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  104. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  105. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  106. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  107. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
  109. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  110. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  111. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  112. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  115. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  116. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  117. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  121. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
  124. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  125. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  128. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
  129. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
  130. package/src/llama.cpp/ggml/src/ggml.c +9 -4
  131. package/src/llama.cpp/include/llama.h +32 -14
  132. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  133. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  134. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  135. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  136. package/src/llama.cpp/requirements.txt +1 -0
  137. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  138. package/src/llama.cpp/src/llama-arch.h +1 -0
  139. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  140. package/src/llama.cpp/src/llama-grammar.cpp +183 -183
  141. package/src/llama.cpp/src/llama-grammar.h +13 -4
  142. package/src/llama.cpp/src/llama-impl.h +6 -6
  143. package/src/llama.cpp/src/llama-kv-cache.h +2 -1
  144. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  145. package/src/llama.cpp/src/llama-mmap.h +1 -0
  146. package/src/llama.cpp/src/llama-model.cpp +70 -6
  147. package/src/llama.cpp/src/llama-sampling.cpp +174 -67
  148. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  149. package/src/llama.cpp/src/llama.cpp +154 -5
  150. package/src/llama.cpp/src/unicode.cpp +9 -2
  151. package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
  152. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  153. package/src/llama.cpp/tests/test-chat.cpp +691 -325
  154. package/src/llama.cpp/tests/test-gguf.cpp +4 -4
  155. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  156. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  157. package/src/llama.cpp/tests/test-sampling.cpp +15 -0
  158. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  159. package/src/llama.cpp/common/chat.hpp +0 -52
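
Note: the hunks below show the detailed diff for a single file from the list above. Judging by the includes and function names (ggml_sycl_init, ggml_backend_sycl_print_sycl_devices), this appears to be package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp, entry 123 (+174 -728).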
@@ -39,8 +39,13 @@
  #include "ggml-sycl/backend.hpp"
  #include "ggml-sycl/presets.hpp"
  #include "ggml-sycl/gemm.hpp"
+ #include "ggml-sycl/sycl_hw.hpp"
+ #include "ggml-sycl/getrows.hpp"
+ #include "ggml.h"

  static bool g_sycl_loaded = false;
+ int g_ggml_sycl_debug = 0;
+ int g_ggml_sycl_disable_optimize = 0;

  static ggml_sycl_device_info ggml_sycl_init() {
  ggml_sycl_device_info info = {};
@@ -63,14 +68,18 @@ static ggml_sycl_device_info ggml_sycl_init() {
  for (int i = 0; i < info.device_count; ++i) {
  info.devices[i].vmm = 0;
  dpct::device_info prop;
+ sycl::device device = dpct::dev_mgr::instance().get_device(i);
+
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
- prop, dpct::dev_mgr::instance().get_device(i))));
+ prop, device)));

  info.default_tensor_split[i] = total_vram;
  total_vram += prop.get_global_mem_size();

  info.devices[i].cc =
  100 * prop.get_major_version() + 10 * prop.get_minor_version();
+ info.devices[i].hw_info = get_device_hw_info(&device);
+ info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch);

  info.max_work_group_sizes[i] = prop.get_max_work_group_size();
  }
@@ -103,13 +112,33 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
  name = std::regex_replace(name, std::regex("\\(TM\\)"), "");

  auto global_mem_size = prop.get_global_mem_size()/1000000;
- std::string xmx = gpu_has_xmx(device) ? "yes" : "no";
- GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(),
+ GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
  name.c_str(), version.c_str(), prop.get_max_compute_units(),
  prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
- global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str(), xmx.c_str());
+ global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
  }

+ void print_device_opt_feature(int device_count) {
+ GGML_LOG_INFO("SYCL Optimization Feature:\n");
+ GGML_LOG_INFO(
+ "|ID| Device Type|Reorder|\n");
+ GGML_LOG_INFO(
+ "|--|-------------------|-------|\n");
+ std::map<std::string, size_t> DeviceNums;
+ for (int id = 0; id < device_count; ++id) {
+ sycl::device device = dpct::dev_mgr::instance().get_device(id);
+ std::string backend_type = get_device_backend_and_type(device);
+ int type_id = DeviceNums[backend_type]++;
+ std::stringstream device_type;
+ device_type << "[" << backend_type << ":" << std::to_string(type_id)
+ << "]";
+ std::string device_type_s = device_type.str();
+ device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
+ GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
+ ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
+ }
+
+ }
  void ggml_backend_sycl_print_sycl_devices() {
  GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
  int device_count = dpct::dev_mgr::instance().device_count();
@@ -118,16 +147,16 @@ void ggml_backend_sycl_print_sycl_devices() {

  GGML_LOG_INFO(
  "| | | | "
- " |Max | |Max |Global | | XMX |\n");
+ " |Max | |Max |Global | |\n");
  GGML_LOG_INFO(
  "| | | | "
- " |compute|Max work|sub |mem | | or |\n");
+ " |compute|Max work|sub |mem | |\n");
  GGML_LOG_INFO(
  "|ID| Device Type| "
- "Name|Version|units |group |group|size | Driver version| Tensor Cores |\n");
+ "Name|Version|units |group |group|size | Driver version|\n");
  GGML_LOG_INFO(
  "|--|-------------------|---------------------------------------|------"
- "-|-------|--------|-----|-------|---------------------|--------------|\n");
+ "-|-------|--------|-----|-------|---------------------|\n");

  for (int id = 0; id < device_count; ++id) {
  sycl::device device = dpct::dev_mgr::instance().get_device(id);
@@ -138,6 +167,8 @@ void ggml_backend_sycl_print_sycl_devices() {
  << "]";
  print_device_detail(id, device, device_type.str());
  }
+
+ print_device_opt_feature(device_count);
  }

  static inline int get_sycl_env(const char *env_name, int default_val) {
@@ -158,18 +189,22 @@ static void ggml_check_sycl() try {
  static bool initialized = false;

  if (!initialized) {
- GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
  g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
- GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+ g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
+ GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+ GGML_LOG_INFO("Running with Environment Variables:\n");
+ GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+ GGML_LOG_INFO("Build with Macros:\n");
  #if defined(GGML_SYCL_FORCE_MMQ)
- GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n");
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
  #else
- GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n");
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
  #endif
  #if defined(GGML_SYCL_F16)
- GGML_LOG_INFO("GGML_SYCL_F16: yes\n");
+ GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
  #else
- GGML_LOG_INFO("GGML_SYCL_F16: no\n");
+ GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
  #endif

  /* NOT REMOVE, keep it for next optimize for XMX.
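
The two new globals are driven by environment variables read through get_sycl_env(), whose signature appears in the context above but whose body is not part of this diff. A minimal sketch of what a helper with that signature typically looks like (an assumption, not the released implementation):

    #include <cstdlib>

    // Hypothetical reconstruction of get_sycl_env(): read an integer-valued
    // environment variable (e.g. GGML_SYCL_DEBUG) and fall back to a default.
    static inline int get_sycl_env_sketch(const char * env_name, int default_val) {
        const char * val = std::getenv(env_name);
        return val ? std::atoi(val) : default_val;
    }

In practice this means GGML_SYCL_DEBUG=1 enables the debug traces and GGML_SYCL_DISABLE_OPT=1 turns off the graph optimization pass introduced later in this diff.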
@@ -241,19 +276,27 @@ struct ggml_backend_sycl_buffer_context {
  void * dev_ptr = nullptr;
  queue_ptr stream;
  std::string name;
+ optimize_feature opt_feature;
+ std::vector<ggml_tensor_extra_gpu *> tensor_extras;

- ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
+ ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
  device(device), dev_ptr(dev_ptr), stream(stream) {
  check_allow_gpu_index(device);
  name = (GGML_SYCL_NAME + std::to_string(device));
+ opt_feature = ggml_sycl_info().devices[device].opt_feature;
  }

-
  ~ggml_backend_sycl_buffer_context() {
  if (dev_ptr != nullptr) {
  ggml_sycl_set_device(device);
  SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
  }
+
+ //release extra used by tensors
+ for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+ release_extra_gpu(extra);
+ }
+
  }
  };

@@ -281,16 +324,19 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
  return ctx->dev_ptr;
  }

- static void
+ static enum ggml_status
  ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor *tensor) try {
  ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;

  if (tensor->view_src != NULL) {
  assert(tensor->view_src->buffer->buft == buffer->buft);
- return;
+ return GGML_STATUS_SUCCESS;
  }

+ ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+ tensor->extra = extra;
+ ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.

  if (ggml_is_quantized(tensor->type)) {
  // initialize padding to 0 to avoid possible NaN values
@@ -303,6 +349,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
  padded_size - original_size).wait()));
  }
  }
+ return GGML_STATUS_SUCCESS;
  }
  catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
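
Both SYCL init_tensor callbacks change from void to enum ggml_status in this release, matching the updated backend interface (see ggml-backend-impl.h, entry 82 above). A minimal sketch of the new callback shape, using a hypothetical backend name:

    // Hypothetical backend callback illustrating the new contract: report
    // success explicitly instead of returning void. Real implementations
    // would return a failure status (e.g. GGML_STATUS_ALLOC_FAILED) on error.
    static enum ggml_status my_backend_init_tensor(ggml_backend_buffer_t buffer,
                                                   ggml_tensor * tensor) {
        (void) buffer;
        (void) tensor; // nothing to initialize in this toy backend
        return GGML_STATUS_SUCCESS;
    }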
@@ -316,7 +363,6 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
  size_t size) try {

  ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-
  ggml_sycl_set_device(ctx->device);
  auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
  SYCL_CHECK(
@@ -660,32 +706,7 @@ struct ggml_backend_sycl_split_buffer_type_context {
  struct ggml_backend_sycl_split_buffer_context {
  ~ggml_backend_sycl_split_buffer_context() try {
  for (ggml_tensor_extra_gpu * extra : tensor_extras) {
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
- for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
- if (extra->events[i][is] != nullptr) {
- /*
- DPCT1009:206: SYCL uses exceptions to report errors and
- does not use the error codes. The original code was
- commented out and a warning string was inserted. You
- need to rewrite this code.
- */
- SYCL_CHECK(CHECK_TRY_ERROR(
- dpct::destroy_event(extra->events[i][is])));
- }
- }
- if (extra->data_device[i] != nullptr) {
- /*
- DPCT1009:207: SYCL uses exceptions to report errors and does
- not use the error codes. The original code was commented out
- and a warning string was inserted. You need to rewrite this
- code.
- */
- ggml_sycl_set_device(i);
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
- extra->data_device[i], *(streams[i]))));
- }
- }
- delete extra;
+ release_extra_gpu(extra, streams);
  }
  }
  catch (sycl::exception const &exc) {
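
The per-device event and allocation teardown that used to be written inline here is replaced by a call to release_extra_gpu(). Its definition is not shown in this diff; based on the removed body, a plausible shape is (hypothetical reconstruction):

    // Hypothetical reconstruction from the removed code: destroy per-stream
    // events, free per-device allocations on their streams, then delete extra.
    static void release_extra_gpu_sketch(ggml_tensor_extra_gpu * extra,
                                         std::vector<queue_ptr> streams = {}) {
        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
            for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
                if (extra->events[i][is] != nullptr) {
                    SYCL_CHECK(CHECK_TRY_ERROR(dpct::destroy_event(extra->events[i][is])));
                }
            }
            if (extra->data_device[i] != nullptr && !streams.empty()) {
                ggml_sycl_set_device(i);
                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i]))));
            }
        }
        delete extra;
    }

The non-split buffer destructor earlier in this diff calls release_extra_gpu(extra) with no stream list, which is why the streams parameter is sketched as optional here.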
@@ -710,7 +731,7 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff
  GGML_UNUSED(buffer);
  }

- static void
+ static enum ggml_status
  ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor *tensor) try {
  GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
@@ -723,7 +744,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
  ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};

  ctx->tensor_extras.push_back(extra);
- ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
+ ctx->streams.push_back(&(dpct::get_current_device().default_queue()));

  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
  int64_t row_low, row_high;
@@ -785,6 +806,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
  }
  }
  tensor->extra = extra;
+ return GGML_STATUS_SUCCESS;
  }
  catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -1264,8 +1286,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
  // struct ggml_sycl_pool_vmm : public ggml_sycl_pool

  /// kernels
-
- typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
  typedef void (*ggml_sycl_op_mul_mat_t)(
  ggml_backend_sycl_context & ctx,
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1337,83 +1357,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
  reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
  }

- template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
- static void k_get_rows(
- const void * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12,
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
- const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2)) *
- 2;
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) /
- ne12;
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) %
- ne12;
-
- if (i00 >= ne00) {
- return;
- }
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
-
- const int ib = i00/qk; // block index
- const int iqs = (i00%qk)/qr; // quant index
- const int iybs = i00 - i00%qk; // dst block start index
- const int y_offset = qr == 1 ? 1 : qk/2;
-
- // dequantize
- dfloat2 v;
- dequantize_kernel(src0_row, ib, iqs, v);
-
- dst_row[iybs + iqs + 0] = v.x();
- dst_row[iybs + iqs + y_offset] = v.y();
- }
-
- template<typename src0_t, typename dst_t>
- static void k_get_rows_float(
- const src0_t * src0, const int32_t * src1, dst_t * dst,
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
- size_t s10, size_t s11, size_t s12,
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
- const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
- item_ct1.get_local_id(2);
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) /
- ne12;
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
- item_ct1.get_local_id(0)) %
- ne12;
-
- if (i00 >= ne00) {
- return;
- }
-
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
-
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
-
- dst_row[i00] = src0_row[i00];
- }
-
  static void mul_mat_p021_f16_f32(
  const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
  const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
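
The k_get_rows/k_get_rows_float kernels removed here, together with the cpy_* kernels and their host wrappers removed in the next several hunks, do not simply disappear: this release adds ggml-sycl/getrows.{cpp,hpp} and ggml-sycl/cpy.{cpp,hpp} (entries 117–118 and 121–122 above), and the top of this file now includes ggml-sycl/getrows.hpp, so the code appears to have been split out into those new translation units rather than dropped.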
@@ -1524,193 +1467,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
  }
  }

- static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- float * dsti = (float *) cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- sycl::half *dsti = (sycl::half *)cdsti;
-
- *dsti = sycl::vec<float, 1>(*xi)
- .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
- }
-
- static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
- const sycl::half *xi = (const sycl::half *)cxi;
- sycl::half *dsti = (sycl::half *)cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
- const sycl::half *xi = (const sycl::half *)cxi;
- float * dsti = (float *) cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
- const int16_t *xi = (const int16_t *)cxi;
- int16_t *dsti = (int16_t *)cdsti;
-
- *dsti = *xi;
- }
-
- static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
- const int32_t *xi = (const int32_t *)cxi;
- int32_t *dsti = (int32_t *)cdsti;
-
- *dsti = *xi;
- }
-
- template <cpy_kernel_t cpy_1>
- static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (i >= ne) {
- return;
- }
-
- // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
- // then combine those indices with the corresponding byte offsets to get the total offsets
- const int i03 = i/(ne00 * ne01 * ne02);
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
- const int i13 = i/(ne10 * ne11 * ne12);
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
-
- cpy_1(cx + x_offset, cdst + dst_offset);
- }
-
- static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q8_0 * dsti = (block_q8_0 *) cdsti;
-
- float amax = 0.0f; // absolute max
-
- for (int j = 0; j < QK8_0; j++) {
- const float v = xi[j];
- amax = sycl::fmax(amax, sycl::fabs((float)v));
- }
-
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->d = d;
-
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = xi[j]*id;
-
- dsti->qs[j] = sycl::round((float)x0);
- }
- }
-
- static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q4_0 * dsti = (block_q4_0 *) cdsti;
-
- float amax = 0.0f;
- float vmax = 0.0f;
-
- for (int j = 0; j < QK4_0; ++j) {
- const float v = xi[j];
- if (amax < sycl::fabs((float)v)) {
- amax = sycl::fabs((float)v);
- vmax = v;
- }
- }
-
- const float d = vmax / -8;
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->d = d;
-
- for (int j = 0; j < QK4_0/2; ++j) {
- const float x0 = xi[0 + j]*id;
- const float x1 = xi[QK4_0/2 + j]*id;
-
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
-
- dsti->qs[j] = xi0;
- dsti->qs[j] |= xi1 << 4;
- }
- }
-
- static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
- const float * xi = (const float *) cxi;
- block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
- float vmin = FLT_MAX;
- float vmax = -FLT_MAX;
-
- for (int j = 0; j < QK4_1; ++j) {
- const float v = xi[j];
-
- if (v < vmin) vmin = v;
- if (v > vmax) vmax = v;
- }
-
- const float d = (vmax - vmin) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- dsti->dm.x() = d;
- dsti->dm.y() = vmin;
-
- for (int j = 0; j < QK4_1/2; ++j) {
- const float x0 = (xi[0 + j] - vmin)*id;
- const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
-
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
-
- dsti->qs[j] = xi0;
- dsti->qs[j] |= xi1 << 4;
- }
- }
-
- template <cpy_kernel_t cpy_blck, int qk>
- static void cpy_f32_q(const char * cx, char * cdst, const int ne,
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
- const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2)) *
- qk;
-
- if (i >= ne) {
- return;
- }
-
- const int i03 = i/(ne00 * ne01 * ne02);
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
-
- const int i13 = i/(ne10 * ne11 * ne12);
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
-
- cpy_blck(cx + x_offset, cdst + dst_offset);
- }
-
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
  const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(1);
@@ -1896,81 +1652,6 @@ static void pool2d_nchw_kernel(
  o_ptr[cur_oh * ow + cur_ow] = res;
  }

- template <int qk, int qr, dequantize_kernel_t dq>
- static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const void *src0_dd,
- const int32_t *src1_dd, float *dst_dd,
- queue_ptr stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
- const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
- // strides in elements
- //const size_t s0 = nb0 / ggml_element_size(dst);
- const size_t s1 = nb1 / ggml_element_size(dst);
- const size_t s2 = nb2 / ggml_element_size(dst);
- const size_t s3 = nb3 / ggml_element_size(dst);
-
- const size_t s10 = nb10 / ggml_element_size(src1);
- const size_t s11 = nb11 / ggml_element_size(src1);
- const size_t s12 = nb12 / ggml_element_size(src1);
- //const size_t s13 = nb13 / ggml_element_size(src1);
-
- GGML_ASSERT(ne00 % 2 == 0);
-
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_get_rows<qk, qr, dq>(
- src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
- });
-
- GGML_UNUSED(dst);
- GGML_UNUSED(ctx);
- }
-
- template <typename src0_t>
- static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const src0_t *src0_dd, const int32_t *src1_dd,
- float *dst_dd, queue_ptr stream) {
-
- GGML_TENSOR_BINARY_OP_LOCALS
-
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
- const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
-
- // strides in elements
- //const size_t s0 = nb0 / ggml_element_size(dst);
- const size_t s1 = nb1 / ggml_element_size(dst);
- const size_t s2 = nb2 / ggml_element_size(dst);
- const size_t s3 = nb3 / ggml_element_size(dst);
-
- const size_t s10 = nb10 / ggml_element_size(src1);
- const size_t s11 = nb11 / ggml_element_size(src1);
- const size_t s12 = nb12 / ggml_element_size(src1);
- //const size_t s13 = nb13 / ggml_element_size(src1);
-
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
- });
- }
-
- GGML_UNUSED(dst);
- GGML_UNUSED(ctx);
- }
-
  static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
  const int ky, const int kx_padded,
  queue_ptr stream) {
@@ -2034,231 +1715,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
  }
  }

- static void
- ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
- const int ne01, const int ne02, const int nb00,
- const int nb01, const int nb02, const int nb03,
- const int ne10, const int ne11, const int ne12,
- const int nb10, const int nb11, const int nb12,
- const int nb13, queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});

- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
- nb01, nb02, nb03, ne10, ne11, ne12,
- nb10, nb11, nb12, nb13, item_ct1);
- });
- }
- }
-
- static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }
-
- static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }
-
- static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK8_0 == 0);
- const int num_blocks = ne / QK8_0;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-
- static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK4_0 == 0);
- const int num_blocks = ne / QK4_0;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-
- static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- GGML_ASSERT(ne % QK4_1 == 0);
- const int num_blocks = ne / QK4_1;
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
- sycl::range<3>(1, 1, 1)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
-
- static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- dpct::has_capability_or_fail(stream->get_device(),
- {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }
-
- static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- // dpct::has_capability_or_fail(stream->get_device(),
- // {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }
-
- static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
- const int ne00, const int ne01,
- const int ne02, const int nb00,
- const int nb01, const int nb02,
- const int nb03, const int ne10,
- const int ne11, const int ne12,
- const int nb10, const int nb11,
- const int nb12, const int nb13,
- queue_ptr stream) {
-
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
- {
- // dpct::has_capability_or_fail(stream->get_device(),
- // {sycl::aspect::fp16});
-
- stream->parallel_for(
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
- [=](sycl::nd_item<3> item_ct1) {
- cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
- item_ct1);
- });
- }
- }

  static void scale_f32_sycl(const float *x, float *dst, const float scale,
  const int k, queue_ptr stream) {
@@ -2494,52 +1951,6 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

- static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
- const ggml_tensor *src1, ggml_tensor *dst,
- const float *src0_d, const float *src1_d,
- float *dst_d, const queue_ptr &stream) {
-
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
- const int32_t * src1_i32 = (const int32_t *) src1_d;
-
- switch (src0->type) {
- case GGML_TYPE_F16:
- get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
- src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_F32:
- get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q4_0:
- get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q4_1:
- get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q5_0:
- get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q5_1:
- get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- case GGML_TYPE_Q8_0:
- get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
- break;
- default:
- // TODO: k-quants
- GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
- GGML_ABORT("fatal error");
- break;
- }
- }
-
-
  static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
  const ggml_tensor *src1, ggml_tensor *dst,
  const float *src0_d, const float *src1_d,
@@ -2589,11 +2000,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
  if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
  use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
  dst->op_params[0] == GGML_PREC_DEFAULT) {
-
  // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
  ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
  if (src0->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
  GGML_ASSERT(to_fp16_sycl != nullptr);
  size_t ne = row_diff*ne00;
  src0_as_f16.alloc(ne);
@@ -2605,7 +2015,7 @@ inline void ggml_sycl_op_mul_mat_sycl(

  ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
  if (src1->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
  GGML_ASSERT(to_fp16_sycl != nullptr);
  size_t ne = src1_ncols*ne10;
  src1_as_f16.alloc(ne);
@@ -2626,13 +2036,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
  src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
  dst_f16.get(), dpct::library_data_t::real_half, ldc,
  dpct::library_data_t::real_half)));
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
  #else
  auto dnnl_stream = ctx.stream_dnnl(stream);
  DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
  src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
  #endif
  }
@@ -2641,13 +2051,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
  ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
  ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
  if (src0->type != GGML_TYPE_F32) {
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
  GGML_ASSERT(to_fp32_sycl != nullptr);
  src0_ddq_as_f32.alloc(row_diff*ne00);
  to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
  }
  if (src1->type != GGML_TYPE_F32) {
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
  GGML_ASSERT(to_fp32_sycl != nullptr);
  src1_ddq_as_f32.alloc(src1_ncols*ne10);
  to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
@@ -3085,7 +2495,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
  for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
  const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
  const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
  if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
  continue;
@@ -3393,7 +2802,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
  // convert src1 to fp16
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
  if (src1->type != GGML_TYPE_F16) {
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
  const int64_t ne_src1 = ggml_nelements(src1);
  src1_f16_alloc.alloc(ne_src1);
  GGML_ASSERT(to_fp16_sycl != nullptr);
@@ -3509,6 +2918,7 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
  }

  static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+
  const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
  int64_t min_compute_capability = INT_MAX;

@@ -3570,6 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
  } else if (use_dequantize_mul_mat_vec) {
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+ // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
  } else if (use_mul_mat_vec_q) {
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
  } else if (use_mul_mat_q) {
@@ -3822,58 +3233,6 @@ static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
  ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
  }

- static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst) try {
- const int64_t ne = ggml_nelements(src0);
- GGML_ASSERT(ne == ggml_nelements(src1));
-
- GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
- GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
- GGML_TENSOR_BINARY_OP_LOCALS01;
-
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
- queue_ptr main_stream = ctx.stream();
-
- char * src0_ddc = (char *) src0->data;
- char * src1_ddc = (char *) src1->data;
-
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
- ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
- ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
- ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
- ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
- ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
- ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
- ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
- } else {
- GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__,
- ggml_type_name(src0->type), ggml_type_name(src1->type));
- GGML_ABORT("fatal error");
- }
- GGML_UNUSED(dst);
- }
- catch (sycl::exception const &exc) {
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
- << ", line:" << __LINE__ << std::endl;
- std::exit(1);
- }
-
- static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- // TODO: why do we pass dst as src1 here?
- ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr);
- }
-
  static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
  ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
  }
@@ -4070,7 +3429,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
  ggml_sycl_clamp(ctx, dst);
  break;
  case GGML_OP_CPY:
- ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst);
+ ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
  break;
  case GGML_OP_CONT:
  ggml_sycl_dup(ctx, dst);
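
The trailing dst argument is dropped from the ggml_sycl_cpy() call because the removed implementation above never used it (note the GGML_UNUSED(dst)); the function itself, along with ggml_sycl_dup(), presumably now lives in the new ggml-sycl/cpy.{cpp,hpp} files.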
@@ -4251,10 +3610,72 @@ catch (sycl::exception const &exc) {
  std::exit(1);
  }

+ void reorder_qw(char *data_device, const int ncols, const int nrows,
+ size_t size, size_t offset, dpct::queue_ptr stream) {
+ auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+ SYCL_CHECK(
+ CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
+ .wait()));
+ GGML_ASSERT((size % sizeof(block_q4_0) == 0));
+ GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
+ int offset_blks = offset / sizeof(block_q4_0);
+ auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
+ auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
+
+ stream->parallel_for(
+ size / sizeof(block_q4_0),
+ [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+ const block_q4_0* x = (const block_q4_0*)tmp_buf;
+ const int ib = i;
+
+ for (int j = 0; j < QK4_0/2; j ++)
+ {
+ *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
+ }
+ *(d_ptr + ib) = x[ib].d;
+ });
+
+ sycl::free(tmp_buf, *stream);
+ }
+
+ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
+ char*data_device = (char*)src0->data;
+ size_t ncols = src0->ne[0];
+ size_t nrows = src0->ne[1];
+ size_t size = ggml_nbytes(src0);
+
+ reorder_qw(data_device, ncols, nrows, size, 0, stream);
+ }
+
+ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
+ ggml_tensor *src0 = dst->src[0];
+ ggml_tensor *src1 = dst->src[1];
+
+ if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
+ src1->ne[2]==1 && src1->ne[3]==1) {
+ reorder_qw(src0, stream);
+ ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
+ GGML_ASSERT(extra);
+ extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
+ }
+ }
+
+ void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
+ dpct::queue_ptr stream = ctx->stream();
+ if (ctx->optimized_graph) {
+ return;
+ }
+ ctx->optimized_graph = true;
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
+ }
+ }
  static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
  ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
  ggml_sycl_set_main_device(sycl_ctx->device);

+ if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);

  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_tensor * node = cgraph->nodes[i];
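
The reorder optimization rewrites Q4_0 weights in place, from an array of {qs, d} blocks into two flat regions: all packed nibbles first, then all scales. The offset arithmetic implied by reorder_qw() above, spelled out (an illustrative restatement, not code from the package):

    // For a Q4_0 tensor of nblocks = size / sizeof(block_q4_0) blocks:
    //   region 1: packed nibbles, QK4_0/2 bytes per block
    //   region 2: one sycl::half scale per block, starting right after region 1
    size_t       nblocks = size / sizeof(block_q4_0);  // size == ggml_nbytes(src0)
    uint8_t    * qs_base = (uint8_t *) data_device;                      // qs of block ib: qs_base + ib*QK4_0/2
    sycl::half * d_base  = (sycl::half *)(qs_base + ncols * nrows / 2);  // d of block ib:  d_base[ib]

optimize_graph_once() applies this once per context, only to Q4_0 MUL_MAT operands with a single-batch src1, and the whole pass can be disabled at runtime with GGML_SYCL_DISABLE_OPT=1.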
@@ -4444,7 +3865,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_TANH:
  case GGML_UNARY_OP_EXP:
- return ggml_is_contiguous(op->src[0]);
+ return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
  default:
  return false;
  }
@@ -4522,6 +3943,30 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
  return true;
  }
+ if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+ return true;
+ }
  return false;
  } break;
  case GGML_OP_CONCAT:
@@ -4537,23 +3982,24 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
  case GGML_OP_TRANSPOSE:
+ return true;
  case GGML_OP_ADD:
  case GGML_OP_ADD1:
- case GGML_OP_LOG:
  case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
- return true;
- case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_GROUP_NORM:
- return ggml_is_contiguous(op->src[0]);
- case GGML_OP_SCALE:
  case GGML_OP_SQR:
  case GGML_OP_SQRT:
  case GGML_OP_SIN:
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
+ case GGML_OP_LOG:
+ return (op->src[0]->type == GGML_TYPE_F32);
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_GROUP_NORM:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_SCALE:
  return true;
  case GGML_OP_CONT:
  return op->src[0]->type != GGML_TYPE_BF16;