@fugood/llama.node 0.3.6 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/README.md +17 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +3 -1
  19. package/lib/index.js +16 -1
  20. package/lib/index.ts +16 -0
  21. package/package.json +1 -1
  22. package/src/EmbeddingWorker.cpp +4 -3
  23. package/src/LlamaCompletionWorker.cpp +4 -2
  24. package/src/LlamaContext.cpp +61 -6
  25. package/src/LlamaContext.h +1 -0
  26. package/src/common.hpp +6 -11
  27. package/src/llama.cpp/.github/workflows/build.yml +19 -17
  28. package/src/llama.cpp/.github/workflows/docker.yml +77 -30
  29. package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +22 -3
  31. package/src/llama.cpp/CMakeLists.txt +49 -24
  32. package/src/llama.cpp/common/arg.cpp +82 -26
  33. package/src/llama.cpp/common/arg.h +3 -0
  34. package/src/llama.cpp/common/common.cpp +192 -72
  35. package/src/llama.cpp/common/common.h +51 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +12 -12
  37. package/src/llama.cpp/common/ngram-cache.h +2 -2
  38. package/src/llama.cpp/common/sampling.cpp +11 -6
  39. package/src/llama.cpp/common/speculative.cpp +18 -15
  40. package/src/llama.cpp/docs/build.md +2 -0
  41. package/src/llama.cpp/examples/batched/batched.cpp +9 -7
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
  43. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
  44. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
  45. package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
  46. package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
  47. package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
  48. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
  50. package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
  51. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
  52. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
  53. package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
  54. package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
  55. package/src/llama.cpp/examples/infill/infill.cpp +23 -24
  56. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
  57. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
  58. package/src/llama.cpp/examples/llava/clip.cpp +4 -2
  59. package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
  60. package/src/llama.cpp/examples/llava/llava.cpp +2 -2
  61. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
  62. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
  63. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
  64. package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
  65. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
  66. package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
  67. package/src/llama.cpp/examples/main/main.cpp +51 -29
  68. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
  69. package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
  70. package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
  71. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
  72. package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
  73. package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
  74. package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
  76. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
  77. package/src/llama.cpp/examples/run/run.cpp +175 -61
  78. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
  79. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
  80. package/src/llama.cpp/examples/server/httplib.h +1295 -409
  81. package/src/llama.cpp/examples/server/server.cpp +387 -181
  82. package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
  83. package/src/llama.cpp/examples/server/utils.hpp +170 -58
  84. package/src/llama.cpp/examples/simple/simple.cpp +9 -8
  85. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
  86. package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
  87. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
  88. package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
  89. package/src/llama.cpp/examples/tts/tts.cpp +64 -23
  90. package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
  91. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  92. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
  93. package/src/llama.cpp/ggml/include/ggml.h +36 -145
  94. package/src/llama.cpp/ggml/include/gguf.h +202 -0
  95. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  96. package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
  97. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
  98. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
  99. package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
  100. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
  101. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
  102. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
  103. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
  105. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
  106. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
  107. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  109. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
  111. package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
  112. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
  113. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
  115. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
  117. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
  120. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
  121. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
  124. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
  125. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
  126. package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
  128. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
  129. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
  130. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
  131. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
  132. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
  133. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
  134. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
  135. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
  138. package/src/llama.cpp/ggml/src/ggml.c +117 -1327
  139. package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
  140. package/src/llama.cpp/include/llama-cpp.h +6 -1
  141. package/src/llama.cpp/include/llama.h +138 -75
  142. package/src/llama.cpp/src/CMakeLists.txt +13 -1
  143. package/src/llama.cpp/src/llama-adapter.cpp +347 -0
  144. package/src/llama.cpp/src/llama-adapter.h +74 -0
  145. package/src/llama.cpp/src/llama-arch.cpp +1487 -0
  146. package/src/llama.cpp/src/llama-arch.h +400 -0
  147. package/src/llama.cpp/src/llama-batch.cpp +368 -0
  148. package/src/llama.cpp/src/llama-batch.h +88 -0
  149. package/src/llama.cpp/src/llama-chat.cpp +578 -0
  150. package/src/llama.cpp/src/llama-chat.h +52 -0
  151. package/src/llama.cpp/src/llama-context.cpp +1775 -0
  152. package/src/llama.cpp/src/llama-context.h +128 -0
  153. package/src/llama.cpp/src/llama-cparams.cpp +1 -0
  154. package/src/llama.cpp/src/llama-cparams.h +37 -0
  155. package/src/llama.cpp/src/llama-grammar.cpp +5 -4
  156. package/src/llama.cpp/src/llama-grammar.h +3 -1
  157. package/src/llama.cpp/src/llama-hparams.cpp +71 -0
  158. package/src/llama.cpp/src/llama-hparams.h +139 -0
  159. package/src/llama.cpp/src/llama-impl.cpp +167 -0
  160. package/src/llama.cpp/src/llama-impl.h +16 -136
  161. package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
  162. package/src/llama.cpp/src/llama-kv-cache.h +218 -0
  163. package/src/llama.cpp/src/llama-mmap.cpp +589 -0
  164. package/src/llama.cpp/src/llama-mmap.h +67 -0
  165. package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
  166. package/src/llama.cpp/src/llama-model-loader.h +167 -0
  167. package/src/llama.cpp/src/llama-model.cpp +3953 -0
  168. package/src/llama.cpp/src/llama-model.h +370 -0
  169. package/src/llama.cpp/src/llama-quant.cpp +934 -0
  170. package/src/llama.cpp/src/llama-quant.h +1 -0
  171. package/src/llama.cpp/src/llama-sampling.cpp +147 -32
  172. package/src/llama.cpp/src/llama-sampling.h +3 -19
  173. package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
  174. package/src/llama.cpp/src/llama-vocab.h +97 -142
  175. package/src/llama.cpp/src/llama.cpp +7160 -20314
  176. package/src/llama.cpp/src/unicode.cpp +8 -3
  177. package/src/llama.cpp/tests/CMakeLists.txt +2 -0
  178. package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
  179. package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
  180. package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
  181. package/src/llama.cpp/tests/test-gguf.cpp +222 -187
  182. package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
  183. package/src/llama.cpp/tests/test-sampling.cpp +0 -1
  184. package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
  185. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
  186. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp (entry 125 above; selected hunks)
@@ -54,18 +54,12 @@ static ggml_sycl_device_info ggml_sycl_init() {
     GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_SYCL_FORCE_MMQ)
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: no\n", __func__);
-#endif
-#if defined(SYCL_USE_XMX)
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-    GGML_LOG_INFO("%s: found %d %s devices:\n", __func__, info.device_count, GGML_SYCL_NAME);
-
+    /* This is a bit misleading; reserved for later */
+    // #if defined(SYCL_USE_XMX)
+    //     GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
+    // #else
+    //     GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
+    // #endif
     for (int i = 0; i < info.device_count; ++i) {
         info.devices[i].vmm = 0;
         dpct::device_info prop;
@@ -109,11 +103,11 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
     name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
 
     auto global_mem_size = prop.get_global_mem_size()/1000000;
-
-    GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
+    std::string xmx = gpu_has_xmx(device) ? "yes" : "no";
+    GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(),
                   name.c_str(), version.c_str(), prop.get_max_compute_units(),
                   prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
-                  global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
+                  global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str(), xmx.c_str());
 }
 
 void ggml_backend_sycl_print_sycl_devices() {
@@ -124,16 +118,16 @@ void ggml_backend_sycl_print_sycl_devices() {
 
     GGML_LOG_INFO(
         "| | | | "
-        " |Max | |Max |Global | |\n");
+        " |Max | |Max |Global | | XMX |\n");
     GGML_LOG_INFO(
         "| | | | "
-        " |compute|Max work|sub |mem | |\n");
+        " |compute|Max work|sub |mem | | or |\n");
     GGML_LOG_INFO(
         "|ID| Device Type| "
-        "Name|Version|units |group |group|size | Driver version|\n");
+        "Name|Version|units |group |group|size | Driver version| Tensor Cores |\n");
     GGML_LOG_INFO(
         "|--|-------------------|---------------------------------------|------"
-        "-|-------|--------|-----|-------|---------------------|\n");
+        "-|-------|--------|-----|-------|---------------------|--------------|\n");
 
     for (int id = 0; id < device_count; ++id) {
         sycl::device device = dpct::dev_mgr::instance().get_device(id);
@@ -164,14 +158,18 @@ static void ggml_check_sycl() try {
     static bool initialized = false;
 
     if (!initialized) {
-        GGML_LOG_INFO("[SYCL] call ggml_check_sycl\n");
+        GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        GGML_LOG_INFO("%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
-
+        GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+#if defined(GGML_SYCL_FORCE_MMQ)
+        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n");
+#else
+        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n");
+#endif
 #if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: yes\n", __func__);
+        GGML_LOG_INFO("GGML_SYCL_F16: yes\n");
 #else
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: no\n", __func__);
+        GGML_LOG_INFO("GGML_SYCL_F16: no\n");
 #endif
 
     /* NOT REMOVE, keep it for next optimize for XMX.
@@ -288,10 +286,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
                                          ggml_tensor *tensor) try {
     ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
         return;
     }
 
@@ -539,7 +535,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
     auto dev_count = ggml_backend_sycl_get_device_count();
 
     if (device>=dev_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
                device, dev_count-1);
         GGML_ASSERT(device<dev_count);
     }
@@ -567,7 +563,7 @@ ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_conte
 
     int device = ctx->device;
     if (device>=ggml_sycl_info().device_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+        GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
                device, ggml_sycl_info().device_count-1);
         GGML_ASSERT(device<ggml_sycl_info().device_count);
     }
@@ -746,7 +742,7 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
         size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
     }
 
-    // FIXME: do not crash if cudaMalloc fails
+    // FIXME: do not crash if SYCL Buffer alloc fails
     // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
     ggml_sycl_set_device(i);
     const queue_ptr stream = ctx->streams[i];
@@ -788,7 +784,6 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
                 CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
         }
     }
-    tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
     tensor->extra = extra;
 }
 catch (sycl::exception const &exc) {
@@ -1178,6 +1173,85 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool {
     }
 };
 
+struct ggml_sycl_pool_host : public ggml_sycl_pool {
+    queue_ptr qptr;
+    int       device;
+
+    inline static int counter{ 0 };
+
+    struct ggml_sycl_buffer {
+        void * ptr  = nullptr;
+        size_t size = 0;
+    };
+
+    // Set arbitrarily to 64
+    static constexpr int MAX_POOL_SIZE{ 64 };
+    std::vector<ggml_sycl_buffer> buffer_pool = std::vector<ggml_sycl_buffer>(MAX_POOL_SIZE);
+    size_t pool_size = 0;
+
+    explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {}
+
+    ~ggml_sycl_pool_host() {
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                b.ptr = nullptr;
+                pool_size -= b.size;
+                b.size = 0;
+            }
+        }
+        counter = 0;
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+        if (counter == MAX_POOL_SIZE) {
+            ggml_sycl_buffer b = buffer_pool[0];
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            counter = 1;
+            return ptr;
+        }
+        ggml_sycl_buffer & b = buffer_pool[counter];
+
+        if (b.ptr == nullptr) {
+            void * ptr;
+
+            SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr)));
+            if (!ptr) {
+                GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size);
+                return nullptr;
+            }
+            pool_size += size;
+            *actual_size = size;
+            counter = counter + 1;
+            return ptr;
+        } else {
+            ++counter;
+            b.size = size;
+            return b.ptr;
+        }
+    }
+
+    void free(void * ptr, size_t size) override {
+        // if the pool is not completed add the pointer to it in place of the first nullptr found.
+        // Otherwise do nothing, pointers will be freed once the pool is deallocated.
+        for (int i = 0; i < MAX_POOL_SIZE; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr  = ptr;
+                b.size = size;
+                return;
+            }
+        }
+    }
+};
+
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) {
+    // return pool for the host to speed up memory management
+    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_host(qptr, device));
+}
+
 std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
     // TBD: NO VMM support
     // if (ggml_sycl_info().devices[device].vmm) {
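The hunk above adds a recycling pool for pinned host allocations: a fixed table of 64 slots that hands previously allocated blocks back out in order instead of paying for sycl::malloc_host on every request, while free() parks pointers back into the table rather than releasing them. A minimal, self-contained sketch of the recycling scheme (std::malloc standing in for sycl::malloc_host; the type names are illustrative, not from the diff):

    #include <array>
    #include <cstdlib>

    struct recycling_pool {                       // illustrative stand-in
        struct slot { void * ptr = nullptr; size_t size = 0; };
        static constexpr int MAX_POOL_SIZE = 64;  // same fixed capacity as the diff
        std::array<slot, MAX_POOL_SIZE> slots{};
        int counter = 0;

        void * alloc(size_t size, size_t * actual_size) {
            if (counter == MAX_POOL_SIZE) {       // table exhausted: wrap around and
                *actual_size = slots[0].size;     // recycle slot 0 unchanged (assumes
                counter = 1;                      // earlier blocks came back via free())
                return slots[0].ptr;
            }
            slot & s = slots[counter++];
            if (s.ptr != nullptr) {               // warm slot: reuse the cached block
                *actual_size = s.size = size;     // (assumes uniform request sizes)
                return s.ptr;
            }
            void * p = std::malloc(size);         // cold slot: one real allocation;
            if (p) *actual_size = size;           // the block enters the table later,
            return p;                             // when the caller free()s it
        }

        void free(void * ptr, size_t size) {      // park in the first empty slot;
            for (slot & s : slots) {              // blocks stay alive for reuse
                if (s.ptr == nullptr) { s.ptr = ptr; s.size = size; return; }
            }
        }

        ~recycling_pool() {                       // real release happens only here
            for (slot & s : slots) std::free(s.ptr);
        }
    };

Its first consumer appears further down in this diff, where the batched GEMM path draws a matrix_info_t scratch record from ctx.host_pool().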
@@ -1192,7 +1266,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
 /// kernels
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
-typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_sycl_op_mul_mat_t)(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -2349,12 +2422,22 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
 
     dpct::memcpy_direction kind;
     char * src_ptr;
-    if (src->backend == GGML_BACKEND_TYPE_CPU) {
+    if (ggml_backend_buffer_is_host(src->buffer)) {
         kind = dpct::host_to_device;
+        //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__);
         src_ptr = (char *) src->data;
         // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
-    } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
-        GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
+    } else if (ggml_backend_buffer_is_sycl(src->buffer)) {
+        // If buffer is a SYCL buffer
+        //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__);
+        kind = dpct::device_to_device;
+        src_ptr = (char *) src->data;
+    } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) {
+        /*
+        If buffer is a SYCL split buffer
+        */
+        //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__);
+        GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]);
         kind = dpct::device_to_device;
         ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
         int id;
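This hunk is the clearest instance of the migration that runs through the whole file: the removed per-tensor backend enum (GGML_BACKEND_TYPE_CPU / _GPU / _GPU_SPLIT) is replaced by predicates on the buffer the tensor actually lives in. A compilable model of the new three-way classification, with stub types standing in for the real ggml_backend_buffer_is_host / ggml_backend_buffer_is_sycl / ggml_backend_buffer_is_sycl_split predicates used above:

    #include <stdexcept>

    enum class copy_kind { host_to_device, device_to_device };
    struct buffer { bool host = false, sycl = false, split = false; };
    struct tensor { const buffer * buf = nullptr; };

    // stand-ins for the ggml_backend_buffer_is_* predicates from the diff
    static bool is_host(const buffer * b)  { return b->host; }
    static bool is_sycl(const buffer * b)  { return b->sycl; }
    static bool is_split(const buffer * b) { return b->split; }

    static copy_kind classify_src(const tensor & src) {
        if (is_host(src.buf))  return copy_kind::host_to_device;   // plain host memory
        if (is_sycl(src.buf))  return copy_kind::device_to_device; // regular SYCL buffer
        if (is_split(src.buf)) return copy_kind::device_to_device; // row-split multi-GPU buffer
        throw std::runtime_error("unsupported source buffer");
    }

The advantage over the old enum is that the property is derived from where the data is stored, so it cannot drift out of sync with the tensor the way the removed tensor->backend field could.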
@@ -2857,8 +2940,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
 
-    GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer));
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer));
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
     GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
@@ -2878,7 +2961,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
 
     int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
 
-    const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
+    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
@@ -3164,33 +3247,33 @@ catch (sycl::exception const &exc) {
 }
 
 
-static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
@@ -3198,7 +3281,7 @@ static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const gg
                                            const ggml_tensor *src1,
                                            ggml_tensor *dst) try {
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
     GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -3231,7 +3314,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
@@ -3293,7 +3376,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
                                            ggml_tensor *dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
 
     GGML_TENSOR_BINARY_OP_LOCALS
@@ -3359,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
 
         ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+        ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
 
         sycl::range<3> block_dims(1, ne12, ne13);
         /*
@@ -3387,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
             });
         }
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
-            cu_compute_type)));
+            *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+            (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
+            (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
     }
 }
 catch (sycl::exception const &exc) {
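These two hunks are the first consumer of the pinned-host pool added earlier: the matrix_info_t scratch record that dpct::gemm_batch now takes as an explicit final argument is drawn from ctx.host_pool() instead of being allocated per call. A usage sketch mirroring the diff lines above:

    // scratch lives until the RAII guard leaves scope, then returns to the pool
    ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
    // ... fill ptrs_src / ptrs_dst, then:
    // dpct::gemm_batch(..., cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get());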
@@ -3565,9 +3645,10 @@ __dpct_inline__ static void k_copy_dst_from_contiguous(
         }
     }
 
-static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1,
+static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
                                  ggml_tensor *dst) try {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     const ggml_tensor *ids = dst->src[2];
@@ -3733,12 +3814,12 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
+static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale);
 }
 
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
+static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
 }
 
 static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
@@ -3780,7 +3861,6 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ABORT("fatal error");
     }
-
     GGML_UNUSED(dst);
 }
 catch (sycl::exception const &exc) {
@@ -3789,59 +3869,52 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     // TODO: why do we pass dst as src1 here?
-    ggml_sycl_cpy(ctx, src0, dst, nullptr);
-    GGML_UNUSED(src1);
+    ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr);
 }
 
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
+static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
+static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
 }
 
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
+static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
 }
 
-static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
+static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d);
 }
 
-static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
+static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col);
 }
 
-static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
+static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum);
 }
 
-static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
+static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows);
 }
 
-static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
+static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort);
 }
 
-static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
+static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax);
 }
 
-static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
 
 void ggml_sycl_set_main_device(const int main_device) try {
     if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
@@ -3864,191 +3937,192 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
+bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
     if (!g_sycl_loaded) return false;
 
-    ggml_sycl_func_t func;
+    if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
+        ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
+    }
 
-    switch (tensor->op) {
+    switch (dst->op) {
         case GGML_OP_ARGMAX:
-            func = ggml_sycl_argmax;
+            ggml_sycl_argmax(ctx, dst);
             break;
         case GGML_OP_CONV_TRANSPOSE_1D:
-            func = ggml_sycl_op_conv_transpose_1d;
+            ggml_sycl_op_conv_transpose_1d(ctx, dst);
             break;
         case GGML_OP_REPEAT:
-            func = ggml_sycl_repeat;
+            ggml_sycl_repeat(ctx, dst);
            break;
         case GGML_OP_GET_ROWS:
-            func = ggml_sycl_get_rows;
+            ggml_sycl_get_rows(ctx, dst);
             break;
         case GGML_OP_DUP:
-            func = ggml_sycl_dup;
+            ggml_sycl_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
         case GGML_OP_ADD1: // TODO: more efficient implementation
-            func = ggml_sycl_add;
+            ggml_sycl_add(ctx, dst);
             break;
         case GGML_OP_SUB:
-            func = ggml_sycl_sub;
+            ggml_sycl_sub(ctx, dst);
             break;
         case GGML_OP_ACC:
-            func = ggml_sycl_acc;
+            ggml_sycl_acc(ctx, dst);
             break;
         case GGML_OP_MUL:
-            func = ggml_sycl_mul;
+            ggml_sycl_mul(ctx, dst);
             break;
         case GGML_OP_LOG:
-            func = ggml_sycl_log;
+            ggml_sycl_log(ctx, dst);
             break;
         case GGML_OP_DIV:
-            func = ggml_sycl_div;
+            ggml_sycl_div(ctx, dst);
             break;
         case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(tensor)) {
+            switch (ggml_get_unary_op(dst)) {
                 case GGML_UNARY_OP_NEG:
-                    func = ggml_sycl_neg;
+                    ggml_sycl_neg(ctx, dst);
                     break;
                 case GGML_UNARY_OP_STEP:
-                    func = ggml_sycl_step;
+                    ggml_sycl_step(ctx, dst);
                     break;
                 case GGML_UNARY_OP_GELU:
-                    func = ggml_sycl_gelu;
+                    ggml_sycl_gelu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SILU:
-                    func = ggml_sycl_silu;
+                    ggml_sycl_silu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_GELU_QUICK:
-                    func = ggml_sycl_gelu_quick;
+                    ggml_sycl_gelu_quick(ctx, dst);
                     break;
                 case GGML_UNARY_OP_TANH:
-                    func = ggml_sycl_tanh;
+                    ggml_sycl_tanh(ctx, dst);
                     break;
                 case GGML_UNARY_OP_RELU:
-                    func = ggml_sycl_relu;
+                    ggml_sycl_relu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-                    func = ggml_sycl_sigmoid;
+                    ggml_sycl_sigmoid(ctx, dst);
                     break;
                 case GGML_UNARY_OP_HARDSIGMOID:
-                    func = ggml_sycl_hardsigmoid;
+                    ggml_sycl_hardsigmoid(ctx, dst);
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
-                    func = ggml_sycl_hardswish;
+                    ggml_sycl_hardswish(ctx, dst);
                     break;
                 case GGML_UNARY_OP_EXP:
-                    func = ggml_sycl_exp;
+                    ggml_sycl_exp(ctx, dst);
                     break;
                 default:
                     return false;
             }
             break;
         case GGML_OP_NORM:
-            func = ggml_sycl_norm;
+            ggml_sycl_norm(ctx, dst);
             break;
         case GGML_OP_GROUP_NORM:
-            func = ggml_sycl_group_norm;
+            ggml_sycl_group_norm(ctx, dst);
             break;
         case GGML_OP_CONCAT:
-            func = ggml_sycl_op_concat;
+            ggml_sycl_op_concat(ctx, dst);
             break;
         case GGML_OP_UPSCALE:
-            func = ggml_sycl_upscale;
+            ggml_sycl_upscale(ctx, dst);
             break;
         case GGML_OP_PAD:
-            func = ggml_sycl_pad;
+            ggml_sycl_pad(ctx, dst);
             break;
         case GGML_OP_LEAKY_RELU:
-            func = ggml_sycl_leaky_relu;
+            ggml_sycl_leaky_relu(ctx, dst);
             break;
         case GGML_OP_RMS_NORM:
-            func = ggml_sycl_rms_norm;
+            ggml_sycl_rms_norm(ctx, dst);
             break;
         case GGML_OP_MUL_MAT:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
             }
-            func = ggml_sycl_mul_mat;
+            /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
+            ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
             }
-            func = ggml_sycl_mul_mat_id;
+            ggml_sycl_mul_mat_id(ctx, dst);
             break;
         case GGML_OP_OUT_PROD:
-            func = ggml_sycl_op_out_prod;
+            ggml_sycl_op_out_prod(ctx, dst);
             break;
         case GGML_OP_SCALE:
-            func = ggml_sycl_scale;
+            ggml_sycl_scale(ctx, dst);
             break;
         case GGML_OP_SQR:
-            func = ggml_sycl_sqr;
+            ggml_sycl_sqr(ctx, dst);
             break;
         case GGML_OP_SQRT:
-            func = ggml_sycl_sqrt;
+            ggml_sycl_sqrt(ctx, dst);
             break;
         case GGML_OP_SIN:
-            func = ggml_sycl_sin;
+            ggml_sycl_sin(ctx, dst);
             break;
         case GGML_OP_COS:
-            func = ggml_sycl_cos;
+            ggml_sycl_cos(ctx, dst);
             break;
         case GGML_OP_CLAMP:
-            func = ggml_sycl_clamp;
+            ggml_sycl_clamp(ctx, dst);
             break;
         case GGML_OP_CPY:
-            func = ggml_sycl_cpy;
+            ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_CONT:
-            func = ggml_sycl_dup;
+            ggml_sycl_dup(ctx, dst);
             break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-            func = ggml_sycl_nop;
+            GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
             break;
         case GGML_OP_DIAG_MASK_INF:
-            func = ggml_sycl_diag_mask_inf;
+            ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            func = ggml_sycl_soft_max;
+            ggml_sycl_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
-            func = ggml_sycl_rope;
+            ggml_sycl_rope(ctx, dst);
             break;
         case GGML_OP_IM2COL:
-            func = ggml_sycl_im2col;
+            ggml_sycl_im2col(ctx, dst);
             break;
         case GGML_OP_POOL_2D:
-            func = ggml_sycl_pool2d;
+            ggml_sycl_pool2d(ctx, dst);
             break;
         case GGML_OP_SUM:
-            func = ggml_sycl_sum;
+            ggml_sycl_sum(ctx, dst);
             break;
         case GGML_OP_SUM_ROWS:
-            func = ggml_sycl_sum_rows;
+            ggml_sycl_sum_rows(ctx, dst);
             break;
         case GGML_OP_ARGSORT:
-            func = ggml_sycl_argsort;
+            ggml_sycl_argsort(ctx, dst);
             break;
         case GGML_OP_TIMESTEP_EMBEDDING:
-            func = ggml_sycl_op_timestep_embedding;
+            ggml_sycl_op_timestep_embedding(ctx, dst);
             break;
         case GGML_OP_RWKV_WKV6:
-            func = ggml_sycl_op_rwkv_wkv6;
+            ggml_sycl_op_rwkv_wkv6(ctx, dst);
+            break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
         default:
             return false;
     }
 
-    if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
-        ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
-    }
-
-    func(ctx, tensor->src[0], tensor->src[1], tensor);
     return true;
 }
 
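The structural change in this hunk: ggml_sycl_compute_forward no longer selects a uniform ggml_sycl_func_t pointer (the typedef deleted earlier in this diff) and invokes it after the switch; each case now calls its handler directly, and the split-buffer peer-access setup moves ahead of the switch. Direct calls are what let most handlers shrink to a (ctx, dst) signature and pull their operands from dst->src[]. A minimal, self-contained model of the new shape (illustrative names, not from the diff):

    #include <cstdio>

    struct ctx_t {};
    struct node { int op = 0; node * src[2] = {}; };

    // handlers take only (ctx, dst) and read inputs from dst->src[]
    static void op_add(ctx_t &, node * dst) {
        std::printf("add: %p + %p\n", (void *) dst->src[0], (void *) dst->src[1]);
    }

    static bool compute_forward(ctx_t & ctx, node * dst) {
        switch (dst->op) {
            case 0: op_add(ctx, dst); break; // was: func = op_add; ... func(ctx, src0, src1, dst);
            default: return false;
        }
        return true;
    }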
@@ -4512,6 +4586,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_RWKV_WKV6:
+    case GGML_OP_GATED_LINEAR_ATTN:
         return true;
     default:
         return false;
@@ -4638,10 +4713,9 @@ static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t re
 static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) {
     GGML_UNUSED(reg);
 
-    // TODO: update to the current function signature
-    //if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
-    //    return (void *)ggml_backend_sycl_split_buffer_type;
-    //}
+    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+        return (void *)ggml_backend_sycl_split_buffer_type;
+    }
 
     // SYCL doesn't support registering host memory, left here for reference
     // "ggml_backend_register_host_buffer"