@fugood/llama.node 0.3.14 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -297,8 +297,27 @@ static int ggml_backend_opencl_n_devices = 0;
 struct ProfilingInfo {
     std::string op_name;
     std::string kernel_name;
-    // Kernel execution time in nanoseconds.
-    cl_ulong duration_ns;
+
+    cl_kernel kernel;
+    cl_event evt;
+
+    cl_ulong cmd_queued;
+    cl_ulong cmd_submit;
+    cl_ulong cmd_start;
+    cl_ulong cmd_end;
+    cl_ulong overhead_start;
+    cl_ulong overhead_end;
+    // For the times below, see spec for clGetEventProfilingInfo
+    // The time kernel spent in cmd queue - SUBMIT - QUEUED
+    cl_ulong cmd_queued_duration_ns;
+    // The time kernel spent for submission - START - SUBMIT
+    cl_ulong cmd_submit_duration_ns;
+    // Kernel execution time in nanoseconds - END - START
+    cl_ulong cmd_duration_ns;
+    // The time for the kernel to complete - COMPLETE - END
+    cl_ulong cmd_complete_duration_ns;
+    // Total time to finish the kernel - COMPELTE - QUEUED
+    cl_ulong cmd_total_duration_ns;
     // Global and local work sizes.
     size_t global_size[3];
     size_t local_size[3];

@@ -903,12 +922,56 @@ static void ggml_cl2_free(void) {
         return;
     }
 
+    // Populate profiling info
+    for (ProfilingInfo & info : g_profiling_info) {
+        cl_ulong cmd_queued;
+        cl_ulong cmd_submit;
+        cl_ulong cmd_start;
+        cl_ulong cmd_end;
+        cl_ulong cmd_complete;
+
+        CL_CHECK(clWaitForEvents(1, &info.evt));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL));
+        CL_CHECK(clGetEventProfilingInfo(
+            info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL));
+        CL_CHECK(clReleaseEvent(info.evt));
+
+        char kernel_name[512];
+        CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME,
+            sizeof(kernel_name), kernel_name, NULL));
+        info.kernel_name = kernel_name;
+
+        info.cmd_queued = cmd_queued;
+        info.cmd_submit = cmd_submit;
+        info.cmd_start = cmd_start;
+        info.cmd_end = cmd_end;
+
+        info.cmd_queued_duration_ns = cmd_submit - cmd_queued;
+        info.cmd_submit_duration_ns = cmd_start - cmd_submit;
+        info.cmd_duration_ns = cmd_end - cmd_start;
+        info.cmd_complete_duration_ns = cmd_complete - cmd_end;
+        info.cmd_total_duration_ns = cmd_complete - cmd_queued;
+    }
+
+    // Dump a csv
     float total_kernel_time = 0;
-    fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n");
+    fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
     for (const ProfilingInfo & info : g_profiling_info) {
-        total_kernel_time += info.duration_ns/1.e6f;
-        fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
-            info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
+        total_kernel_time += info.cmd_duration_ns/1.e6f;
+        fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
+            info.op_name.c_str(), info.kernel_name.c_str(),
+            info.cmd_queued_duration_ns/1.e6f,
+            info.cmd_submit_duration_ns/1.e6f,
+            info.cmd_duration_ns/1.e6f,
+            info.cmd_complete_duration_ns/1.e6f,
+            info.cmd_total_duration_ns/1.e6f,
             info.global_size[0], info.global_size[1], info.global_size[2],
             info.local_size[0], info.local_size[2], info.local_size[2],
             info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);

@@ -916,6 +979,27 @@ static void ggml_cl2_free(void) {
     fclose(fperf);
 
     GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
+
+    // Dump a simple chrome trace
+    FILE* ftrace = fopen("cl_trace.json", "w");
+    if (!ftrace) {
+        GGML_LOG_ERROR("Failed to open cl_trace.json\n");
+        return;
+    }
+
+    fprintf(ftrace, "[\n");
+    for (const ProfilingInfo & info : g_profiling_info) {
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            info.kernel_name.c_str(), info.cmd_queued/1000);
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
+            info.kernel_name.c_str(), info.cmd_submit/1000);
+
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            info.kernel_name.c_str(), info.cmd_start/1000);
+        fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
+            info.kernel_name.c_str(), info.cmd_end/1000);
+    }
+    fclose(ftrace);
 #endif
 }
 

@@ -2062,25 +2146,14 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor
 // Profiling utility
 //------------------------------------------------------------------------------
 #ifdef GGML_OPENCL_PROFILING
-void populateProfilingInfo(
+static void populateProfilingInfo(
         ProfilingInfo& info, cl_event evt, cl_kernel kernel,
         size_t global_size[3], size_t local_size[3],
         const ggml_tensor * tensor) {
-    cl_ulong start;
-    cl_ulong end;
-    CL_CHECK(clWaitForEvents(1, &evt));
-    CL_CHECK(clGetEventProfilingInfo(
-        evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
-    CL_CHECK(clGetEventProfilingInfo(
-        evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
-
-    char kernel_name[512];
-    CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
-        sizeof(kernel_name), kernel_name, NULL));
-
-    info.duration_ns = end - start;
-    info.op_name = tensor->name;
-    info.kernel_name = kernel_name;
+    info.op_name = tensor->name;
+    info.kernel = kernel;
+    info.evt = evt;
+
     info.local_size[0] = local_size[0];
     info.local_size[1] = local_size[1];
     info.local_size[2] = local_size[2];
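
Note on the OpenCL profiling hunks above: profiling is now deferred. populateProfilingInfo() only stashes the kernel handle and the event at enqueue time, and all clGetEventProfilingInfo queries (plus the CSV and cl_trace.json dumps) happen once in ggml_cl2_free(), so the host no longer blocks on every dispatch. A minimal sketch of the same deferred pattern in isolation (hypothetical helper names, not code from the package):

    #include <CL/cl.h>

    /* Keep the event at enqueue time; query timestamps only once the queue has finished. */
    typedef struct {
        cl_event evt;
        cl_ulong queued_ns, submit_ns, start_ns, end_ns;
    } profile_sample;

    static void profile_sample_resolve(profile_sample * s) {
        clWaitForEvents(1, &s->evt);
        clGetEventProfilingInfo(s->evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &s->queued_ns, NULL);
        clGetEventProfilingInfo(s->evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &s->submit_ns, NULL);
        clGetEventProfilingInfo(s->evt, CL_PROFILING_COMMAND_START,  sizeof(cl_ulong), &s->start_ns,  NULL);
        clGetEventProfilingInfo(s->evt, CL_PROFILING_COMMAND_END,    sizeof(cl_ulong), &s->end_ns,    NULL);
        clReleaseEvent(s->evt);
    }

CL_PROFILING_COMMAND_COMPLETE, which the new code also reads, requires an OpenCL 2.0 or newer runtime, and the command queue must be created with CL_QUEUE_PROFILING_ENABLE for any of these queries to succeed. The cl_trace.json written at teardown uses the Chrome Trace Event format and can typically be opened in chrome://tracing or Perfetto.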
@@ -23,6 +23,38 @@ ggml_add_backend_library(ggml-sycl
                          ../../include/ggml-sycl.h
 )
 
+find_package(DNNL)
+set(GGML_SYCL_DNNL 0)
+if(DNNL_FOUND)
+    if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
+        # Assuming oneDNN packaged with oneapi release is used which
+        # supports only intel target
+        set(DNNL_GPU_VENDOR "INTEL")
+        if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
+            message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+        endif()
+    endif()
+
+    # Verify oneDNN was compiled for the same target as llama
+    if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+        target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+        set(GGML_SYCL_DNNL 1)
+        get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
+        foreach(CONFIG ${CONFIGS})
+            get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+            message(STATUS "Found oneDNN: ${DNNL_LIB}")
+        endforeach()
+    else()
+        message(WARNING
+            "oneDNN must be compiled for the same target as llama.cpp.
+            llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+            Disabling oneDNN support.")
+    endif()
+else()
+    message(STATUS "oneDNN not found, disabling oneDNN support")
+endif()
+target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
+
 if (GGML_SYCL_F16)
     if (GGML_SYCL_TARGET STREQUAL "AMD")
         message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")

@@ -48,24 +80,15 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
 file(GLOB GGML_SOURCES_SYCL "*.cpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
-find_package(DNNL)
-message("-- DNNL found:" ${DNNL_FOUND})
-
-if (GGML_SYCL_TARGET STREQUAL "INTEL")
-    add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
-else()
-    add_compile_definitions(GGML_SYCL_DNNL=0)
-endif()
-
-if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
-    target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
-endif()
 
 if (WIN32)
     find_package(IntelSYCL REQUIRED)
     find_package(MKL REQUIRED)
     target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
 else()
+    if (GGML_SYCL_GRAPH)
+        add_compile_definitions(GGML_SYCL_GRAPH)
+    endif()
     if (GGML_SYCL_TARGET STREQUAL "INTEL")
         target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
     elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")

@@ -26,7 +26,7 @@
 #include "softmax.hpp"
 #include "tsembd.hpp"
 #include "im2col.hpp"
-#include "wkv6.hpp"
+#include "wkv.hpp"
 #include "outprod.hpp"
 #include "element_wise.hpp"
 #include "cpy.hpp"

@@ -170,7 +170,6 @@ static size_t g_scratch_offset = 0;
 int get_current_device_id();
 
 inline dpct::err0 ggml_sycl_set_device(const int device) try {
-
     int current_device_id;
     SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
 

@@ -242,6 +241,14 @@ struct ggml_sycl_pool_alloc {
         }
     }
 
+    T * realloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        if (ptr)
+            pool->free(ptr, actual_size);
+        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
     // size is in number of elements
     T * alloc(size_t size) {
         GGML_ASSERT(pool != nullptr);

@@ -301,6 +308,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
     return opt;
 }
 
+namespace sycl_ex = sycl::ext::oneapi::experimental;
 struct ggml_backend_sycl_context {
     int device;
     std::string name;

@@ -370,10 +378,29 @@ struct ggml_backend_sycl_context {
     dnnl::stream stream_dnnl() {
         return stream_dnnl(device, 0);
     }
+    dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
+                                    const dnnl::engine & eng, const queue_ptr q) {
+        ggml_sycl_pool_alloc<uint8_t> * pool;
+        auto it = scratchpad_map.find(q);
+        if (it == scratchpad_map.end()) {
+            scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
+            pool = scratchpad_map[q].get();
+        } else {
+            pool = it->second.get();
+        }
+
+        size_t scratchpad_size = scratchpad_md.get_size();
+        if (scratchpad_size > pool->actual_size) {
+            pool->realloc(scratchpad_size);
+        }
+        void * mem_ptr = pool->get();
+        return dnnl::memory(scratchpad_md, eng, mem_ptr);
+    }
 #endif
 
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+    std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
 
     std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
 
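
Note on get_scratchpad_mem() above: it keeps one grow-only scratchpad allocation per SYCL queue, backed by the ggml SYCL pool via the new realloc() helper, and wraps it in a dnnl::memory for oneDNN primitives. That is only useful when the primitives are created with oneDNN's user scratchpad mode, so the library does not allocate its own workspace. A rough sketch of that usage pattern (generic oneDNN 3.x API; matmul_user_scratchpad and its arguments are made up for illustration, not the package's code):

    #include "oneapi/dnnl/dnnl.hpp"

    // Run a matmul whose scratchpad lives in a caller-managed buffer.
    static void matmul_user_scratchpad(dnnl::engine & eng, dnnl::stream & strm,
                                       const dnnl::memory & a, const dnnl::memory & b, dnnl::memory & c,
                                       void * scratch_buf /* e.g. taken from a ggml SYCL pool */) {
        dnnl::primitive_attr attr;
        attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);   // caller owns the workspace

        dnnl::matmul::primitive_desc pd(eng, a.get_desc(), b.get_desc(), c.get_desc(), attr);

        // scratch_buf must be at least pd.scratchpad_desc().get_size() bytes.
        dnnl::memory scratchpad(pd.scratchpad_desc(), eng, scratch_buf);

        dnnl::matmul(pd).execute(strm, {
            { DNNL_ARG_SRC, a }, { DNNL_ARG_WEIGHTS, b }, { DNNL_ARG_DST, c },
            { DNNL_ARG_SCRATCHPAD, scratchpad },
        });
    }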
@@ -392,6 +419,10 @@ struct ggml_backend_sycl_context {
         return pool(device);
     }
 
+#ifdef GGML_SYCL_GRAPH
+    std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
+#endif
+
     ggml_sycl_pool & host_pool(int device) {
         if (host_pools[device] == nullptr) {
             host_pools[device] = new_pool_for_host(stream(device, 0), device);

@@ -474,6 +505,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1, int s2, int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13,
         const sycl::nd_item<3> &item_ct1) {
     const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +

@@ -495,9 +527,9 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
 
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst = i_src0;
+    const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
 
     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;

@@ -515,6 +547,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1, int s2, int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13,
         const sycl::nd_item<3> &item_ct1) {
 

@@ -534,9 +567,9 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;
 
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst = i_src0;
+    const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
 
     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;

@@ -566,9 +599,11 @@ struct bin_bcast_sycl {
        int nr[4] = { nr0, nr1, nr2, nr3 };
 
        // collapse dimensions until first broadcast dimension
-       int64_t cne0[] = {ne0, ne1, ne2, ne3};
+       int64_t cne[] = {ne0, ne1, ne2, ne3};
+       int64_t cne0[] = {ne00, ne01, ne02, ne03};
        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-       size_t cnb0[] = {nb0, nb1, nb2, nb3};
+       size_t cnb[] = {nb0, nb1, nb2, nb3};
+       size_t cnb0[] = {nb00, nb01, nb02, nb03};
        size_t cnb1[] = {nb10, nb11, nb12, nb13};
        auto collapse = [](int64_t cne[]) {
            cne[0] *= cne[1];

@@ -583,32 +618,41 @@
            cnb[3] *= cne[3];
        };
 
-       for (int i = 0; i < 4; i++) {
-           if (nr[i] != 1) {
-               break;
-           }
-           if (i > 0) {
-               collapse_nb(cnb0, cne0);
-               collapse_nb(cnb1, cne1);
-               collapse(cne0);
-               collapse(cne1);
+       if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+           for (int i = 0; i < 4; i++) {
+               if (nr[i] != 1) {
+                   break;
+               }
+               if (i > 0) {
+                   collapse_nb(cnb, cne);
+                   collapse_nb(cnb0, cne0);
+                   collapse_nb(cnb1, cne1);
+                   collapse(cne);
+                   collapse(cne0);
+                   collapse(cne1);
+               }
            }
        }
        {
-           int64_t ne0 = cne0[0];
-           int64_t ne1 = cne0[1];
-           int64_t ne2 = cne0[2];
-           int64_t ne3 = cne0[3];
+           int64_t ne0 = cne[0];
+           int64_t ne1 = cne[1];
+           int64_t ne2 = cne[2];
+           int64_t ne3 = cne[3];
 
            int64_t ne10 = cne1[0];
            int64_t ne11 = cne1[1];
            int64_t ne12 = cne1[2];
            int64_t ne13 = cne1[3];
 
-           size_t nb0 = cnb0[0];
-           size_t nb1 = cnb0[1];
-           size_t nb2 = cnb0[2];
-           size_t nb3 = cnb0[3];
+           size_t nb0 = cnb[0];
+           size_t nb1 = cnb[1];
+           size_t nb2 = cnb[2];
+           size_t nb3 = cnb[3];
+
+           size_t nb00 = cnb0[0];
+           size_t nb01 = cnb0[1];
+           size_t nb02 = cnb0[2];
+           size_t nb03 = cnb0[3];
 
            size_t nb10 = cnb1[0];
            size_t nb11 = cnb1[1];

@@ -625,6 +669,28 @@ struct bin_bcast_sycl {
            size_t s12 = nb12 / sizeof(src1_t);
            size_t s13 = nb13 / sizeof(src1_t);
 
+           size_t s00 = nb00 / sizeof(src0_t);
+           size_t s01 = nb01 / sizeof(src0_t);
+           size_t s02 = nb02 / sizeof(src0_t);
+           size_t s03 = nb03 / sizeof(src0_t);
+
+           GGML_UNUSED(s00);
+
+           GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+           GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+           GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+           GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+           GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+           GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+           GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+           GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+           GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+           GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+           GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+           GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
            GGML_ASSERT(s0 == 1);
            GGML_ASSERT(s10 == 1);
 

@@ -661,8 +727,8 @@
                    [=](sycl::nd_item<3> item_ct1) {
                        k_bin_bcast_unravel<bin_op>(
                            src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                           ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
-                           s13, item_ct1);
+                           ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
+                           s03, s11, s12, s13, item_ct1);
                    });
            }
        } else {

@@ -680,7 +746,7 @@
                    [=](sycl::nd_item<3> item_ct1) {
                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                           s1, s2, s3, s11, s12, s13,
+                                           s1, s2, s3, s01, s02, s03, s11, s12, s13,
                                            item_ct1);
                    });
            }
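
Note on the binary-broadcast hunks above: previously the kernels reused the src0 element offset as the dst offset, which is only correct when both tensors share the same contiguous layout; the extra s01/s02/s03 strides let src0 and dst be addressed independently, and the dimension-collapsing fast path is now gated on all three tensors being contiguous. A small stand-alone sketch of the per-tensor indexing (element strides, hypothetical helper, not code from the package):

    #include <stddef.h>

    /* Offset of element (i1, i2, i3) given per-dimension strides measured in elements. */
    static inline size_t offset3(size_t i1, size_t i2, size_t i3,
                                 size_t s1, size_t s2, size_t s3) {
        return i3*s3 + i2*s2 + i1*s1;
    }

    /* With the change above, dst and src0 each get their own strides:
     *   i_dst  = offset3(i1, i2, i3, s1,  s2,  s3);    // dst layout
     *   i_src0 = offset3(i1, i2, i3, s01, s02, s03);   // src0 layout (may be a non-contiguous view)
     * instead of the old i_dst = i_src0 shortcut.
     */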
@@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
     stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
                              sycl::range<3>(1, 1, WARP_SIZE),
                          sycl::range<3>(1, 1, WARP_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
             dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
         });
 

@@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
                                                       nrows, item_ct1);
         });

@@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
                 vx, y, dst, ncols, nrows, item_ct1);
         });

@@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
                 vx, y, dst, ncols, nrows, item_ct1);
         });

@@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
                 vx, y, dst, ncols, nrows, item_ct1);
         });

@@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
                 vx, y, dst, ncols, nrows, item_ct1);
         });

@@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
                 vx, y, dst, ncols, nrows, item_ct1);
         });

@@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
 
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
                 vx, y, dst, ncols, nrows, item_ct1);
         });

@@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }

@@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }

@@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }

@@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
         });
 }

@@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
     stream->parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
-        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
+        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
         });
 }
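
Note on the repeated one-line changes above: [[intel::reqd_sub_group_size(N)]] is the Intel vendor spelling, while [[sycl::reqd_sub_group_size(N)]] is the SYCL 2020 standard attribute; both pin the kernel's required sub-group size, but the standard form is portable across SYCL 2020 compilers. A minimal stand-alone sketch of the attribute on a kernel lambda (illustration only, not code from the package):

    #include <sycl/sycl.hpp>

    // Doubles 64 floats using work-groups of 32 and a required sub-group size of 32.
    void scale_in_place(sycl::queue & q, float * data /* USM pointer accessible on the device */) {
        q.parallel_for(sycl::nd_range<1>(sycl::range<1>(64), sycl::range<1>(32)),
                       [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(32)]] {
                           data[it.get_global_id(0)] *= 2.0f;
                       });
    }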
@@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
         default:
             printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
             GGML_ABORT("fatal error");
-            break;
     }
 
     GGML_UNUSED(src1);