@fugood/llama.node 0.3.13 → 0.3.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (139) hide show
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +60 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  25. package/src/llama.cpp/common/arg.cpp +112 -11
  26. package/src/llama.cpp/common/chat.cpp +960 -266
  27. package/src/llama.cpp/common/chat.h +135 -0
  28. package/src/llama.cpp/common/common.cpp +27 -171
  29. package/src/llama.cpp/common/common.h +27 -67
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  31. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  32. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  33. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  34. package/src/llama.cpp/common/sampling.cpp +45 -7
  35. package/src/llama.cpp/common/speculative.cpp +6 -5
  36. package/src/llama.cpp/common/speculative.h +1 -1
  37. package/src/llama.cpp/docs/build.md +45 -7
  38. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
  39. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
  40. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  41. package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
  42. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
  43. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  44. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  45. package/src/llama.cpp/examples/llava/clip.h +19 -3
  46. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  47. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  48. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  49. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
  50. package/src/llama.cpp/examples/main/main.cpp +73 -28
  51. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
  52. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
  53. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  54. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  55. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  56. package/src/llama.cpp/examples/run/run.cpp +110 -67
  57. package/src/llama.cpp/examples/server/server.cpp +82 -87
  58. package/src/llama.cpp/examples/server/utils.hpp +94 -107
  59. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  60. package/src/llama.cpp/examples/tts/tts.cpp +251 -142
  61. package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
  62. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  63. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  64. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  65. package/src/llama.cpp/ggml/include/ggml.h +5 -1
  66. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  67. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  68. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  69. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  70. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  71. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
  72. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  73. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
  74. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  75. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  76. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  77. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  78. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
  79. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
  80. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  81. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  82. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  83. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  84. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  85. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  86. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  87. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  89. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
  90. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
  91. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  92. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
  93. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  94. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  95. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
  96. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  97. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  98. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
  99. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  100. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  101. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  102. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  103. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  104. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
  105. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
  106. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  107. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
  108. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
  109. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  111. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  112. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
  113. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
  114. package/src/llama.cpp/ggml/src/ggml.c +8 -3
  115. package/src/llama.cpp/include/llama.h +19 -5
  116. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  117. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  118. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  119. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  120. package/src/llama.cpp/requirements.txt +1 -0
  121. package/src/llama.cpp/src/llama-arch.cpp +21 -0
  122. package/src/llama.cpp/src/llama-arch.h +1 -0
  123. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  124. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  125. package/src/llama.cpp/src/llama-grammar.h +12 -3
  126. package/src/llama.cpp/src/llama-kv-cache.h +1 -0
  127. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  128. package/src/llama.cpp/src/llama-model.cpp +69 -5
  129. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  130. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  131. package/src/llama.cpp/src/llama.cpp +147 -0
  132. package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
  133. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  134. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  135. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  136. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  137. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  138. package/src/llama.cpp/common/chat.hpp +0 -55
  139. /package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
@@ -39,8 +39,13 @@
39
39
  #include "ggml-sycl/backend.hpp"
40
40
  #include "ggml-sycl/presets.hpp"
41
41
  #include "ggml-sycl/gemm.hpp"
42
+ #include "ggml-sycl/sycl_hw.hpp"
43
+ #include "ggml-sycl/getrows.hpp"
44
+ #include "ggml.h"
42
45
 
43
46
  static bool g_sycl_loaded = false;
47
+ int g_ggml_sycl_debug = 0;
48
+ int g_ggml_sycl_disable_optimize = 0;
44
49
 
45
50
  static ggml_sycl_device_info ggml_sycl_init() {
46
51
  ggml_sycl_device_info info = {};
@@ -63,14 +68,18 @@ static ggml_sycl_device_info ggml_sycl_init() {
63
68
  for (int i = 0; i < info.device_count; ++i) {
64
69
  info.devices[i].vmm = 0;
65
70
  dpct::device_info prop;
71
+ sycl::device device = dpct::dev_mgr::instance().get_device(i);
72
+
66
73
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
67
- prop, dpct::dev_mgr::instance().get_device(i))));
74
+ prop, device)));
68
75
 
69
76
  info.default_tensor_split[i] = total_vram;
70
77
  total_vram += prop.get_global_mem_size();
71
78
 
72
79
  info.devices[i].cc =
73
80
  100 * prop.get_major_version() + 10 * prop.get_minor_version();
81
+ info.devices[i].hw_info = get_device_hw_info(&device);
82
+ info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch);
74
83
 
75
84
  info.max_work_group_sizes[i] = prop.get_max_work_group_size();
76
85
  }
@@ -109,6 +118,27 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
109
118
  global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
110
119
  }
111
120
 
121
+ void print_device_opt_feature(int device_count) {
122
+ GGML_LOG_INFO("SYCL Optimization Feature:\n");
123
+ GGML_LOG_INFO(
124
+ "|ID| Device Type|Reorder|\n");
125
+ GGML_LOG_INFO(
126
+ "|--|-------------------|-------|\n");
127
+ std::map<std::string, size_t> DeviceNums;
128
+ for (int id = 0; id < device_count; ++id) {
129
+ sycl::device device = dpct::dev_mgr::instance().get_device(id);
130
+ std::string backend_type = get_device_backend_and_type(device);
131
+ int type_id = DeviceNums[backend_type]++;
132
+ std::stringstream device_type;
133
+ device_type << "[" << backend_type << ":" << std::to_string(type_id)
134
+ << "]";
135
+ std::string device_type_s = device_type.str();
136
+ device_type_s = std::regex_replace(device_type_s, std::regex("ext_oneapi_"), "");
137
+ GGML_LOG_INFO("|%2d|%19s|%7s|\n", id, device_type_s.c_str(),
138
+ ggml_sycl_info().devices[id].opt_feature.reorder ? "Y": "N");
139
+ }
140
+
141
+ }
112
142
  void ggml_backend_sycl_print_sycl_devices() {
113
143
  GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
114
144
  int device_count = dpct::dev_mgr::instance().device_count();
@@ -137,6 +167,8 @@ void ggml_backend_sycl_print_sycl_devices() {
137
167
  << "]";
138
168
  print_device_detail(id, device, device_type.str());
139
169
  }
170
+
171
+ print_device_opt_feature(device_count);
140
172
  }
141
173
 
142
174
  static inline int get_sycl_env(const char *env_name, int default_val) {
@@ -157,18 +189,22 @@ static void ggml_check_sycl() try {
157
189
  static bool initialized = false;
158
190
 
159
191
  if (!initialized) {
160
- GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
161
192
  g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
162
- GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
193
+ g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
194
+ GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
195
+ GGML_LOG_INFO("Running with Environment Variables:\n");
196
+ GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
197
+ GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
198
+ GGML_LOG_INFO("Build with Macros:\n");
163
199
  #if defined(GGML_SYCL_FORCE_MMQ)
164
- GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n");
200
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
165
201
  #else
166
- GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n");
202
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
167
203
  #endif
168
204
  #if defined(GGML_SYCL_F16)
169
- GGML_LOG_INFO("GGML_SYCL_F16: yes\n");
205
+ GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
170
206
  #else
171
- GGML_LOG_INFO("GGML_SYCL_F16: no\n");
207
+ GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
172
208
  #endif
173
209
 
174
210
  /* NOT REMOVE, keep it for next optimize for XMX.
@@ -240,19 +276,27 @@ struct ggml_backend_sycl_buffer_context {
240
276
  void * dev_ptr = nullptr;
241
277
  queue_ptr stream;
242
278
  std::string name;
279
+ optimize_feature opt_feature;
280
+ std::vector<ggml_tensor_extra_gpu *> tensor_extras;
243
281
 
244
- ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
282
+ ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
245
283
  device(device), dev_ptr(dev_ptr), stream(stream) {
246
284
  check_allow_gpu_index(device);
247
285
  name = (GGML_SYCL_NAME + std::to_string(device));
286
+ opt_feature = ggml_sycl_info().devices[device].opt_feature;
248
287
  }
249
288
 
250
-
251
289
  ~ggml_backend_sycl_buffer_context() {
252
290
  if (dev_ptr != nullptr) {
253
291
  ggml_sycl_set_device(device);
254
292
  SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
255
293
  }
294
+
295
+ //release extra used by tensors
296
+ for (ggml_tensor_extra_gpu * extra : tensor_extras) {
297
+ release_extra_gpu(extra);
298
+ }
299
+
256
300
  }
257
301
  };
258
302
 
@@ -280,16 +324,19 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
280
324
  return ctx->dev_ptr;
281
325
  }
282
326
 
283
- static void
327
+ static enum ggml_status
284
328
  ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
285
329
  ggml_tensor *tensor) try {
286
330
  ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
287
331
 
288
332
  if (tensor->view_src != NULL) {
289
333
  assert(tensor->view_src->buffer->buft == buffer->buft);
290
- return;
334
+ return GGML_STATUS_SUCCESS;
291
335
  }
292
336
 
337
+ ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
338
+ tensor->extra = extra;
339
+ ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
293
340
 
294
341
  if (ggml_is_quantized(tensor->type)) {
295
342
  // initialize padding to 0 to avoid possible NaN values
@@ -302,6 +349,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
302
349
  padded_size - original_size).wait()));
303
350
  }
304
351
  }
352
+ return GGML_STATUS_SUCCESS;
305
353
  }
306
354
  catch (sycl::exception const &exc) {
307
355
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -315,7 +363,6 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
315
363
  size_t size) try {
316
364
 
317
365
  ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
318
-
319
366
  ggml_sycl_set_device(ctx->device);
320
367
  auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
321
368
  SYCL_CHECK(
@@ -659,32 +706,7 @@ struct ggml_backend_sycl_split_buffer_type_context {
659
706
  struct ggml_backend_sycl_split_buffer_context {
660
707
  ~ggml_backend_sycl_split_buffer_context() try {
661
708
  for (ggml_tensor_extra_gpu * extra : tensor_extras) {
662
- for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
663
- for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
664
- if (extra->events[i][is] != nullptr) {
665
- /*
666
- DPCT1009:206: SYCL uses exceptions to report errors and
667
- does not use the error codes. The original code was
668
- commented out and a warning string was inserted. You
669
- need to rewrite this code.
670
- */
671
- SYCL_CHECK(CHECK_TRY_ERROR(
672
- dpct::destroy_event(extra->events[i][is])));
673
- }
674
- }
675
- if (extra->data_device[i] != nullptr) {
676
- /*
677
- DPCT1009:207: SYCL uses exceptions to report errors and does
678
- not use the error codes. The original code was commented out
679
- and a warning string was inserted. You need to rewrite this
680
- code.
681
- */
682
- ggml_sycl_set_device(i);
683
- SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
684
- extra->data_device[i], *(streams[i]))));
685
- }
686
- }
687
- delete extra;
709
+ release_extra_gpu(extra, streams);
688
710
  }
689
711
  }
690
712
  catch (sycl::exception const &exc) {
@@ -709,7 +731,7 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff
709
731
  GGML_UNUSED(buffer);
710
732
  }
711
733
 
712
- static void
734
+ static enum ggml_status
713
735
  ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
714
736
  ggml_tensor *tensor) try {
715
737
  GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
@@ -722,7 +744,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
722
744
  ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
723
745
 
724
746
  ctx->tensor_extras.push_back(extra);
725
- ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
747
+ ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
726
748
 
727
749
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
728
750
  int64_t row_low, row_high;
@@ -784,6 +806,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
784
806
  }
785
807
  }
786
808
  tensor->extra = extra;
809
+ return GGML_STATUS_SUCCESS;
787
810
  }
788
811
  catch (sycl::exception const &exc) {
789
812
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -1263,8 +1286,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
1263
1286
  // struct ggml_sycl_pool_vmm : public ggml_sycl_pool
1264
1287
 
1265
1288
  /// kernels
1266
-
1267
- typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
1268
1289
  typedef void (*ggml_sycl_op_mul_mat_t)(
1269
1290
  ggml_backend_sycl_context & ctx,
1270
1291
  const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -1336,83 +1357,6 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
1336
1357
  reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
1337
1358
  }
1338
1359
 
1339
- template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1340
- static void k_get_rows(
1341
- const void * src0, const int32_t * src1, dst_t * dst,
1342
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1343
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1344
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1345
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1346
- size_t s10, size_t s11, size_t s12,
1347
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
1348
-
1349
- const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
1350
- item_ct1.get_local_id(2)) *
1351
- 2;
1352
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1353
- item_ct1.get_local_id(1);
1354
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1355
- item_ct1.get_local_id(0)) /
1356
- ne12;
1357
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1358
- item_ct1.get_local_id(0)) %
1359
- ne12;
1360
-
1361
- if (i00 >= ne00) {
1362
- return;
1363
- }
1364
-
1365
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1366
-
1367
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1368
- const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
1369
-
1370
- const int ib = i00/qk; // block index
1371
- const int iqs = (i00%qk)/qr; // quant index
1372
- const int iybs = i00 - i00%qk; // dst block start index
1373
- const int y_offset = qr == 1 ? 1 : qk/2;
1374
-
1375
- // dequantize
1376
- dfloat2 v;
1377
- dequantize_kernel(src0_row, ib, iqs, v);
1378
-
1379
- dst_row[iybs + iqs + 0] = v.x();
1380
- dst_row[iybs + iqs + y_offset] = v.y();
1381
- }
1382
-
1383
- template<typename src0_t, typename dst_t>
1384
- static void k_get_rows_float(
1385
- const src0_t * src0, const int32_t * src1, dst_t * dst,
1386
- int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1387
- /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1388
- /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1389
- /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1390
- size_t s10, size_t s11, size_t s12,
1391
- const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
1392
-
1393
- const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
1394
- item_ct1.get_local_id(2);
1395
- const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1396
- item_ct1.get_local_id(1);
1397
- const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1398
- item_ct1.get_local_id(0)) /
1399
- ne12;
1400
- const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
1401
- item_ct1.get_local_id(0)) %
1402
- ne12;
1403
-
1404
- if (i00 >= ne00) {
1405
- return;
1406
- }
1407
-
1408
- const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
1409
-
1410
- dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1411
- const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
1412
-
1413
- dst_row[i00] = src0_row[i00];
1414
- }
1415
-
1416
1360
  static void mul_mat_p021_f16_f32(
1417
1361
  const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
1418
1362
  const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1523,193 +1467,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1523
1467
  }
1524
1468
  }
1525
1469
 
1526
- static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
1527
- const float * xi = (const float *) cxi;
1528
- float * dsti = (float *) cdsti;
1529
-
1530
- *dsti = *xi;
1531
- }
1532
-
1533
- static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
1534
- const float * xi = (const float *) cxi;
1535
- sycl::half *dsti = (sycl::half *)cdsti;
1536
-
1537
- *dsti = sycl::vec<float, 1>(*xi)
1538
- .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
1539
- }
1540
-
1541
- static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
1542
- const sycl::half *xi = (const sycl::half *)cxi;
1543
- sycl::half *dsti = (sycl::half *)cdsti;
1544
-
1545
- *dsti = *xi;
1546
- }
1547
-
1548
- static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
1549
- const sycl::half *xi = (const sycl::half *)cxi;
1550
- float * dsti = (float *) cdsti;
1551
-
1552
- *dsti = *xi;
1553
- }
1554
-
1555
- static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
1556
- const int16_t *xi = (const int16_t *)cxi;
1557
- int16_t *dsti = (int16_t *)cdsti;
1558
-
1559
- *dsti = *xi;
1560
- }
1561
-
1562
- static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
1563
- const int32_t *xi = (const int32_t *)cxi;
1564
- int32_t *dsti = (int32_t *)cdsti;
1565
-
1566
- *dsti = *xi;
1567
- }
1568
-
1569
- template <cpy_kernel_t cpy_1>
1570
- static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
1571
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
1572
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
1573
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
1574
- const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1575
- item_ct1.get_local_id(2);
1576
-
1577
- if (i >= ne) {
1578
- return;
1579
- }
1580
-
1581
- // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
1582
- // then combine those indices with the corresponding byte offsets to get the total offsets
1583
- const int i03 = i/(ne00 * ne01 * ne02);
1584
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
1585
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
1586
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
1587
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
1588
-
1589
- const int i13 = i/(ne10 * ne11 * ne12);
1590
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
1591
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
1592
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
1593
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
1594
-
1595
- cpy_1(cx + x_offset, cdst + dst_offset);
1596
- }
1597
-
1598
- static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
1599
- const float * xi = (const float *) cxi;
1600
- block_q8_0 * dsti = (block_q8_0 *) cdsti;
1601
-
1602
- float amax = 0.0f; // absolute max
1603
-
1604
- for (int j = 0; j < QK8_0; j++) {
1605
- const float v = xi[j];
1606
- amax = sycl::fmax(amax, sycl::fabs((float)v));
1607
- }
1608
-
1609
- const float d = amax / ((1 << 7) - 1);
1610
- const float id = d ? 1.0f/d : 0.0f;
1611
-
1612
- dsti->d = d;
1613
-
1614
- for (int j = 0; j < QK8_0; ++j) {
1615
- const float x0 = xi[j]*id;
1616
-
1617
- dsti->qs[j] = sycl::round((float)x0);
1618
- }
1619
- }
1620
-
1621
- static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
1622
- const float * xi = (const float *) cxi;
1623
- block_q4_0 * dsti = (block_q4_0 *) cdsti;
1624
-
1625
- float amax = 0.0f;
1626
- float vmax = 0.0f;
1627
-
1628
- for (int j = 0; j < QK4_0; ++j) {
1629
- const float v = xi[j];
1630
- if (amax < sycl::fabs((float)v)) {
1631
- amax = sycl::fabs((float)v);
1632
- vmax = v;
1633
- }
1634
- }
1635
-
1636
- const float d = vmax / -8;
1637
- const float id = d ? 1.0f/d : 0.0f;
1638
-
1639
- dsti->d = d;
1640
-
1641
- for (int j = 0; j < QK4_0/2; ++j) {
1642
- const float x0 = xi[0 + j]*id;
1643
- const float x1 = xi[QK4_0/2 + j]*id;
1644
-
1645
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
1646
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
1647
-
1648
- dsti->qs[j] = xi0;
1649
- dsti->qs[j] |= xi1 << 4;
1650
- }
1651
- }
1652
-
1653
- static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
1654
- const float * xi = (const float *) cxi;
1655
- block_q4_1 * dsti = (block_q4_1 *) cdsti;
1656
-
1657
- float vmin = FLT_MAX;
1658
- float vmax = -FLT_MAX;
1659
-
1660
- for (int j = 0; j < QK4_1; ++j) {
1661
- const float v = xi[j];
1662
-
1663
- if (v < vmin) vmin = v;
1664
- if (v > vmax) vmax = v;
1665
- }
1666
-
1667
- const float d = (vmax - vmin) / ((1 << 4) - 1);
1668
- const float id = d ? 1.0f/d : 0.0f;
1669
-
1670
- dsti->dm.x() = d;
1671
- dsti->dm.y() = vmin;
1672
-
1673
- for (int j = 0; j < QK4_1/2; ++j) {
1674
- const float x0 = (xi[0 + j] - vmin)*id;
1675
- const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
1676
-
1677
- const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
1678
- const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
1679
-
1680
- dsti->qs[j] = xi0;
1681
- dsti->qs[j] |= xi1 << 4;
1682
- }
1683
- }
1684
-
1685
- template <cpy_kernel_t cpy_blck, int qk>
1686
- static void cpy_f32_q(const char * cx, char * cdst, const int ne,
1687
- const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
1688
- const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
1689
- const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
1690
- const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1691
- item_ct1.get_local_id(2)) *
1692
- qk;
1693
-
1694
- if (i >= ne) {
1695
- return;
1696
- }
1697
-
1698
- const int i03 = i/(ne00 * ne01 * ne02);
1699
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
1700
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
1701
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
1702
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
1703
-
1704
- const int i13 = i/(ne10 * ne11 * ne12);
1705
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
1706
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
1707
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
1708
- const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
1709
-
1710
- cpy_blck(cx + x_offset, cdst + dst_offset);
1711
- }
1712
-
1713
1470
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
1714
1471
  const sycl::nd_item<3> &item_ct1) {
1715
1472
  const int row = item_ct1.get_group(1);
@@ -1895,81 +1652,6 @@ static void pool2d_nchw_kernel(
1895
1652
  o_ptr[cur_oh * ow + cur_ow] = res;
1896
1653
  }
1897
1654
 
1898
- template <int qk, int qr, dequantize_kernel_t dq>
1899
- static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
1900
- ggml_tensor *dst, const void *src0_dd,
1901
- const int32_t *src1_dd, float *dst_dd,
1902
- queue_ptr stream) {
1903
-
1904
- GGML_TENSOR_BINARY_OP_LOCALS
1905
-
1906
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
1907
- const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
1908
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
1909
-
1910
- // strides in elements
1911
- //const size_t s0 = nb0 / ggml_element_size(dst);
1912
- const size_t s1 = nb1 / ggml_element_size(dst);
1913
- const size_t s2 = nb2 / ggml_element_size(dst);
1914
- const size_t s3 = nb3 / ggml_element_size(dst);
1915
-
1916
- const size_t s10 = nb10 / ggml_element_size(src1);
1917
- const size_t s11 = nb11 / ggml_element_size(src1);
1918
- const size_t s12 = nb12 / ggml_element_size(src1);
1919
- //const size_t s13 = nb13 / ggml_element_size(src1);
1920
-
1921
- GGML_ASSERT(ne00 % 2 == 0);
1922
-
1923
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
1924
- [=](sycl::nd_item<3> item_ct1) {
1925
- k_get_rows<qk, qr, dq>(
1926
- src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
1927
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
1928
- });
1929
-
1930
- GGML_UNUSED(dst);
1931
- GGML_UNUSED(ctx);
1932
- }
1933
-
1934
- template <typename src0_t>
1935
- static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
1936
- const ggml_tensor *src1, ggml_tensor *dst,
1937
- const src0_t *src0_dd, const int32_t *src1_dd,
1938
- float *dst_dd, queue_ptr stream) {
1939
-
1940
- GGML_TENSOR_BINARY_OP_LOCALS
1941
-
1942
- const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
1943
- const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
1944
- const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
1945
-
1946
- // strides in elements
1947
- //const size_t s0 = nb0 / ggml_element_size(dst);
1948
- const size_t s1 = nb1 / ggml_element_size(dst);
1949
- const size_t s2 = nb2 / ggml_element_size(dst);
1950
- const size_t s3 = nb3 / ggml_element_size(dst);
1951
-
1952
- const size_t s10 = nb10 / ggml_element_size(src1);
1953
- const size_t s11 = nb11 / ggml_element_size(src1);
1954
- const size_t s12 = nb12 / ggml_element_size(src1);
1955
- //const size_t s13 = nb13 / ggml_element_size(src1);
1956
-
1957
- {
1958
- dpct::has_capability_or_fail(stream->get_device(),
1959
- {sycl::aspect::fp16});
1960
-
1961
- stream->parallel_for(
1962
- sycl::nd_range<3>(block_nums * block_dims, block_dims),
1963
- [=](sycl::nd_item<3> item_ct1) {
1964
- k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
1965
- s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
1966
- });
1967
- }
1968
-
1969
- GGML_UNUSED(dst);
1970
- GGML_UNUSED(ctx);
1971
- }
1972
-
1973
1655
  static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
1974
1656
  const int ky, const int kx_padded,
1975
1657
  queue_ptr stream) {
@@ -2033,231 +1715,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
2033
1715
  }
2034
1716
  }
2035
1717
 
2036
- static void
2037
- ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
2038
- const int ne01, const int ne02, const int nb00,
2039
- const int nb01, const int nb02, const int nb03,
2040
- const int ne10, const int ne11, const int ne12,
2041
- const int nb10, const int nb11, const int nb12,
2042
- const int nb13, queue_ptr stream) {
2043
-
2044
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2045
- {
2046
- dpct::has_capability_or_fail(stream->get_device(),
2047
- {sycl::aspect::fp16});
2048
1718
 
2049
- stream->parallel_for(
2050
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2051
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2052
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2053
- [=](sycl::nd_item<3> item_ct1) {
2054
- cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
2055
- nb01, nb02, nb03, ne10, ne11, ne12,
2056
- nb10, nb11, nb12, nb13, item_ct1);
2057
- });
2058
- }
2059
- }
2060
-
2061
- static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
2062
- const int ne00, const int ne01,
2063
- const int ne02, const int nb00,
2064
- const int nb01, const int nb02,
2065
- const int nb03, const int ne10,
2066
- const int ne11, const int ne12,
2067
- const int nb10, const int nb11,
2068
- const int nb12, const int nb13,
2069
- queue_ptr stream) {
2070
-
2071
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2072
- {
2073
- dpct::has_capability_or_fail(stream->get_device(),
2074
- {sycl::aspect::fp16});
2075
-
2076
- stream->parallel_for(
2077
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2078
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2079
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2080
- [=](sycl::nd_item<3> item_ct1) {
2081
- cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2082
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2083
- item_ct1);
2084
- });
2085
- }
2086
- }
2087
-
2088
- static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
2089
- const int ne00, const int ne01,
2090
- const int ne02, const int nb00,
2091
- const int nb01, const int nb02,
2092
- const int nb03, const int ne10,
2093
- const int ne11, const int ne12,
2094
- const int nb10, const int nb11,
2095
- const int nb12, const int nb13,
2096
- queue_ptr stream) {
2097
-
2098
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2099
- {
2100
- dpct::has_capability_or_fail(stream->get_device(),
2101
- {sycl::aspect::fp16});
2102
-
2103
- stream->parallel_for(
2104
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2105
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2106
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2107
- [=](sycl::nd_item<3> item_ct1) {
2108
- cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2109
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2110
- item_ct1);
2111
- });
2112
- }
2113
- }
2114
-
2115
- static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
2116
- const int ne00, const int ne01,
2117
- const int ne02, const int nb00,
2118
- const int nb01, const int nb02,
2119
- const int nb03, const int ne10,
2120
- const int ne11, const int ne12,
2121
- const int nb10, const int nb11,
2122
- const int nb12, const int nb13,
2123
- queue_ptr stream) {
2124
-
2125
- GGML_ASSERT(ne % QK8_0 == 0);
2126
- const int num_blocks = ne / QK8_0;
2127
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
2128
- sycl::range<3>(1, 1, 1)),
2129
- [=](sycl::nd_item<3> item_ct1) {
2130
- cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
2131
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2132
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2133
- item_ct1);
2134
- });
2135
- }
2136
-
2137
- static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
2138
- const int ne00, const int ne01,
2139
- const int ne02, const int nb00,
2140
- const int nb01, const int nb02,
2141
- const int nb03, const int ne10,
2142
- const int ne11, const int ne12,
2143
- const int nb10, const int nb11,
2144
- const int nb12, const int nb13,
2145
- queue_ptr stream) {
2146
-
2147
- GGML_ASSERT(ne % QK4_0 == 0);
2148
- const int num_blocks = ne / QK4_0;
2149
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
2150
- sycl::range<3>(1, 1, 1)),
2151
- [=](sycl::nd_item<3> item_ct1) {
2152
- cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
2153
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2154
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2155
- item_ct1);
2156
- });
2157
- }
2158
-
2159
- static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
2160
- const int ne00, const int ne01,
2161
- const int ne02, const int nb00,
2162
- const int nb01, const int nb02,
2163
- const int nb03, const int ne10,
2164
- const int ne11, const int ne12,
2165
- const int nb10, const int nb11,
2166
- const int nb12, const int nb13,
2167
- queue_ptr stream) {
2168
-
2169
- GGML_ASSERT(ne % QK4_1 == 0);
2170
- const int num_blocks = ne / QK4_1;
2171
- stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
2172
- sycl::range<3>(1, 1, 1)),
2173
- [=](sycl::nd_item<3> item_ct1) {
2174
- cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
2175
- cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2176
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2177
- item_ct1);
2178
- });
2179
- }
2180
-
2181
- static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
2182
- const int ne00, const int ne01,
2183
- const int ne02, const int nb00,
2184
- const int nb01, const int nb02,
2185
- const int nb03, const int ne10,
2186
- const int ne11, const int ne12,
2187
- const int nb10, const int nb11,
2188
- const int nb12, const int nb13,
2189
- queue_ptr stream) {
2190
-
2191
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2192
- {
2193
- dpct::has_capability_or_fail(stream->get_device(),
2194
- {sycl::aspect::fp16});
2195
-
2196
- stream->parallel_for(
2197
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2198
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2199
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2200
- [=](sycl::nd_item<3> item_ct1) {
2201
- cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2202
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2203
- item_ct1);
2204
- });
2205
- }
2206
- }
2207
-
2208
- static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
2209
- const int ne00, const int ne01,
2210
- const int ne02, const int nb00,
2211
- const int nb01, const int nb02,
2212
- const int nb03, const int ne10,
2213
- const int ne11, const int ne12,
2214
- const int nb10, const int nb11,
2215
- const int nb12, const int nb13,
2216
- queue_ptr stream) {
2217
-
2218
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2219
- {
2220
- // dpct::has_capability_or_fail(stream->get_device(),
2221
- // {sycl::aspect::fp16});
2222
-
2223
- stream->parallel_for(
2224
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2225
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2226
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2227
- [=](sycl::nd_item<3> item_ct1) {
2228
- cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2229
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2230
- item_ct1);
2231
- });
2232
- }
2233
- }
2234
-
2235
- static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
2236
- const int ne00, const int ne01,
2237
- const int ne02, const int nb00,
2238
- const int nb01, const int nb02,
2239
- const int nb03, const int ne10,
2240
- const int ne11, const int ne12,
2241
- const int nb10, const int nb11,
2242
- const int nb12, const int nb13,
2243
- queue_ptr stream) {
2244
-
2245
- const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
2246
- {
2247
- // dpct::has_capability_or_fail(stream->get_device(),
2248
- // {sycl::aspect::fp16});
2249
-
2250
- stream->parallel_for(
2251
- sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
2252
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
2253
- sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
2254
- [=](sycl::nd_item<3> item_ct1) {
2255
- cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
2256
- nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
2257
- item_ct1);
2258
- });
2259
- }
2260
- }
2261
1719
 
2262
1720
  static void scale_f32_sycl(const float *x, float *dst, const float scale,
2263
1721
  const int k, queue_ptr stream) {
@@ -2493,52 +1951,6 @@ catch (sycl::exception const &exc) {
2493
1951
  std::exit(1);
2494
1952
  }
2495
1953
 
2496
- static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2497
- const ggml_tensor *src1, ggml_tensor *dst,
2498
- const float *src0_d, const float *src1_d,
2499
- float *dst_d, const queue_ptr &stream) {
2500
-
2501
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
2502
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
2503
-
2504
- GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
2505
- GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
2506
- GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
2507
-
2508
- const int32_t * src1_i32 = (const int32_t *) src1_d;
2509
-
2510
- switch (src0->type) {
2511
- case GGML_TYPE_F16:
2512
- get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
2513
- src1_i32, dst_d, stream);
2514
- break;
2515
- case GGML_TYPE_F32:
2516
- get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2517
- break;
2518
- case GGML_TYPE_Q4_0:
2519
- get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2520
- break;
2521
- case GGML_TYPE_Q4_1:
2522
- get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2523
- break;
2524
- case GGML_TYPE_Q5_0:
2525
- get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2526
- break;
2527
- case GGML_TYPE_Q5_1:
2528
- get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2529
- break;
2530
- case GGML_TYPE_Q8_0:
2531
- get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
2532
- break;
2533
- default:
2534
- // TODO: k-quants
2535
- GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
2536
- GGML_ABORT("fatal error");
2537
- break;
2538
- }
2539
- }
2540
-
2541
-
2542
1954
  static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
2543
1955
  const ggml_tensor *src1, ggml_tensor *dst,
2544
1956
  const float *src0_d, const float *src1_d,
@@ -2588,11 +2000,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
2588
2000
  if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
2589
2001
  use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
2590
2002
  dst->op_params[0] == GGML_PREC_DEFAULT) {
2591
-
2592
2003
  // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
2593
2004
  ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
2594
2005
  if (src0->type != GGML_TYPE_F16) {
2595
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
2006
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst);
2596
2007
  GGML_ASSERT(to_fp16_sycl != nullptr);
2597
2008
  size_t ne = row_diff*ne00;
2598
2009
  src0_as_f16.alloc(ne);
@@ -2604,7 +2015,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
2604
2015
 
2605
2016
  ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
2606
2017
  if (src1->type != GGML_TYPE_F16) {
2607
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
2018
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
2608
2019
  GGML_ASSERT(to_fp16_sycl != nullptr);
2609
2020
  size_t ne = src1_ncols*ne10;
2610
2021
  src1_as_f16.alloc(ne);
@@ -2625,13 +2036,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
2625
2036
  src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
2626
2037
  dst_f16.get(), dpct::library_data_t::real_half, ldc,
2627
2038
  dpct::library_data_t::real_half)));
2628
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2039
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2629
2040
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
2630
2041
  #else
2631
2042
  auto dnnl_stream = ctx.stream_dnnl(stream);
2632
2043
  DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
2633
2044
  src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
2634
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
2045
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
2635
2046
  to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
2636
2047
  #endif
2637
2048
  }
@@ -2640,13 +2051,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
2640
2051
  ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
2641
2052
  ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
2642
2053
  if (src0->type != GGML_TYPE_F32) {
2643
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
2054
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
2644
2055
  GGML_ASSERT(to_fp32_sycl != nullptr);
2645
2056
  src0_ddq_as_f32.alloc(row_diff*ne00);
2646
2057
  to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
2647
2058
  }
2648
2059
  if (src1->type != GGML_TYPE_F32) {
2649
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
2060
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst);
2650
2061
  GGML_ASSERT(to_fp32_sycl != nullptr);
2651
2062
  src1_ddq_as_f32.alloc(src1_ncols*ne10);
2652
2063
  to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
@@ -3084,7 +2495,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
3084
2495
  for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
3085
2496
  const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
3086
2497
  const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
3087
-
3088
2498
  for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
3089
2499
  if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
3090
2500
  continue;
@@ -3392,7 +2802,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
3392
2802
  // convert src1 to fp16
3393
2803
  ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
3394
2804
  if (src1->type != GGML_TYPE_F16) {
3395
- const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
2805
+ const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
3396
2806
  const int64_t ne_src1 = ggml_nelements(src1);
3397
2807
  src1_f16_alloc.alloc(ne_src1);
3398
2808
  GGML_ASSERT(to_fp16_sycl != nullptr);
@@ -3508,6 +2918,7 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
3508
2918
  }
3509
2919
 
3510
2920
  static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2921
+
3511
2922
  const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
3512
2923
  int64_t min_compute_capability = INT_MAX;
3513
2924
 
@@ -3569,6 +2980,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
3569
2980
  ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
3570
2981
  } else if (use_dequantize_mul_mat_vec) {
3571
2982
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
2983
+ // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream());
3572
2984
  } else if (use_mul_mat_vec_q) {
3573
2985
  ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
3574
2986
  } else if (use_mul_mat_q) {
@@ -3821,58 +3233,6 @@ static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
3821
3233
  ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
3822
3234
  }
3823
3235
 
3824
- static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
3825
- ggml_tensor *dst) try {
3826
- const int64_t ne = ggml_nelements(src0);
3827
- GGML_ASSERT(ne == ggml_nelements(src1));
3828
-
3829
- GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
3830
- GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
3831
-
3832
- GGML_TENSOR_BINARY_OP_LOCALS01;
3833
-
3834
- SYCL_CHECK(ggml_sycl_set_device(ctx.device));
3835
- queue_ptr main_stream = ctx.stream();
3836
-
3837
- char * src0_ddc = (char *) src0->data;
3838
- char * src1_ddc = (char *) src1->data;
3839
-
3840
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
3841
- ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3842
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
3843
- ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3844
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
3845
- ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3846
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
3847
- ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3848
- } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
3849
- ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3850
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
3851
- ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3852
- } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
3853
- ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3854
- } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
3855
- ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3856
- } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
3857
- ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
3858
- } else {
3859
- GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__,
3860
- ggml_type_name(src0->type), ggml_type_name(src1->type));
3861
- GGML_ABORT("fatal error");
3862
- }
3863
- GGML_UNUSED(dst);
3864
- }
3865
- catch (sycl::exception const &exc) {
3866
- std::cerr << exc.what() << "Exception caught at file:" << __FILE__
3867
- << ", line:" << __LINE__ << std::endl;
3868
- std::exit(1);
3869
- }
3870
-
3871
- static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3872
- // TODO: why do we pass dst as src1 here?
3873
- ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr);
3874
- }
3875
-
3876
3236
  static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3877
3237
  ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
3878
3238
  }
@@ -4069,7 +3429,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
4069
3429
  ggml_sycl_clamp(ctx, dst);
4070
3430
  break;
4071
3431
  case GGML_OP_CPY:
4072
- ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst);
3432
+ ggml_sycl_cpy(ctx, dst->src[0], dst->src[1]);
4073
3433
  break;
4074
3434
  case GGML_OP_CONT:
4075
3435
  ggml_sycl_dup(ctx, dst);
@@ -4250,10 +3610,72 @@ catch (sycl::exception const &exc) {
4250
3610
  std::exit(1);
4251
3611
  }
4252
3612
 
3613
+ void reorder_qw(char *data_device, const int ncols, const int nrows,
3614
+ size_t size, size_t offset, dpct::queue_ptr stream) {
3615
+ auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
3616
+ SYCL_CHECK(
3617
+ CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
3618
+ .wait()));
3619
+ GGML_ASSERT((size % sizeof(block_q4_0) == 0));
3620
+ GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
3621
+ int offset_blks = offset / sizeof(block_q4_0);
3622
+ auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;;
3623
+ auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
3624
+
3625
+ stream->parallel_for(
3626
+ size / sizeof(block_q4_0),
3627
+ [=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
3628
+ const block_q4_0* x = (const block_q4_0*)tmp_buf;
3629
+ const int ib = i;
3630
+
3631
+ for (int j = 0; j < QK4_0/2; j ++)
3632
+ {
3633
+ *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
3634
+ }
3635
+ *(d_ptr + ib) = x[ib].d;
3636
+ });
3637
+
3638
+ sycl::free(tmp_buf, *stream);
3639
+ }
3640
+
3641
+ void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) {
3642
+ char*data_device = (char*)src0->data;
3643
+ size_t ncols = src0->ne[0];
3644
+ size_t nrows = src0->ne[1];
3645
+ size_t size = ggml_nbytes(src0);
3646
+
3647
+ reorder_qw(data_device, ncols, nrows, size, 0, stream);
3648
+ }
3649
+
3650
+ void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) {
3651
+ ggml_tensor *src0 = dst->src[0];
3652
+ ggml_tensor *src1 = dst->src[1];
3653
+
3654
+ if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 &&
3655
+ src1->ne[2]==1 && src1->ne[3]==1) {
3656
+ reorder_qw(src0, stream);
3657
+ ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
3658
+ GGML_ASSERT(extra);
3659
+ extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
3660
+ }
3661
+ }
3662
+
3663
+ void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) {
3664
+ dpct::queue_ptr stream = ctx->stream();
3665
+ if (ctx->optimized_graph) {
3666
+ return;
3667
+ }
3668
+ ctx->optimized_graph = true;
3669
+
3670
+ for (int i = 0; i < cgraph->n_nodes; i++) {
3671
+ if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream);
3672
+ }
3673
+ }
4253
3674
  static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
4254
3675
  ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
4255
3676
  ggml_sycl_set_main_device(sycl_ctx->device);
4256
3677
 
3678
+ if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx);
4257
3679
 
4258
3680
  for (int i = 0; i < cgraph->n_nodes; i++) {
4259
3681
  ggml_tensor * node = cgraph->nodes[i];
@@ -4443,7 +3865,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4443
3865
  case GGML_UNARY_OP_GELU_QUICK:
4444
3866
  case GGML_UNARY_OP_TANH:
4445
3867
  case GGML_UNARY_OP_EXP:
4446
- return ggml_is_contiguous(op->src[0]);
3868
+ return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32);
4447
3869
  default:
4448
3870
  return false;
4449
3871
  }
@@ -4521,6 +3943,30 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4521
3943
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
4522
3944
  return true;
4523
3945
  }
3946
+ if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
3947
+ return true;
3948
+ }
3949
+ if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
3950
+ return true;
3951
+ }
3952
+ if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
3953
+ return true;
3954
+ }
3955
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
3956
+ return true;
3957
+ }
3958
+ if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
3959
+ return true;
3960
+ }
3961
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
3962
+ return true;
3963
+ }
3964
+ if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
3965
+ return true;
3966
+ }
3967
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
3968
+ return true;
3969
+ }
4524
3970
  return false;
4525
3971
  } break;
4526
3972
  case GGML_OP_CONCAT:
@@ -4536,23 +3982,24 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4536
3982
  case GGML_OP_VIEW:
4537
3983
  case GGML_OP_PERMUTE:
4538
3984
  case GGML_OP_TRANSPOSE:
3985
+ return true;
4539
3986
  case GGML_OP_ADD:
4540
3987
  case GGML_OP_ADD1:
4541
- case GGML_OP_LOG:
4542
3988
  case GGML_OP_SUB:
4543
3989
  case GGML_OP_MUL:
4544
3990
  case GGML_OP_DIV:
4545
- return true;
4546
- case GGML_OP_NORM:
4547
- case GGML_OP_RMS_NORM:
4548
- case GGML_OP_GROUP_NORM:
4549
- return ggml_is_contiguous(op->src[0]);
4550
- case GGML_OP_SCALE:
4551
3991
  case GGML_OP_SQR:
4552
3992
  case GGML_OP_SQRT:
4553
3993
  case GGML_OP_SIN:
4554
3994
  case GGML_OP_COS:
4555
3995
  case GGML_OP_CLAMP:
3996
+ case GGML_OP_LOG:
3997
+ return (op->src[0]->type == GGML_TYPE_F32);
3998
+ case GGML_OP_NORM:
3999
+ case GGML_OP_RMS_NORM:
4000
+ case GGML_OP_GROUP_NORM:
4001
+ return ggml_is_contiguous(op->src[0]);
4002
+ case GGML_OP_SCALE:
4556
4003
  return true;
4557
4004
  case GGML_OP_CONT:
4558
4005
  return op->src[0]->type != GGML_TYPE_BF16;