@fugood/llama.node 0.3.14 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +32 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +12 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +5 -27
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +253 -2
  58. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  60. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  61. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  62. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  63. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -0
  64. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  65. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +66 -26
  66. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  67. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  68. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  69. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  70. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +103 -34
  71. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  72. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  73. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  74. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  75. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  76. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  78. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +352 -146
  79. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +3 -0
  81. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  82. package/src/llama.cpp/include/llama.h +86 -22
  83. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  84. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  85. package/src/llama.cpp/src/llama-adapter.h +11 -9
  86. package/src/llama.cpp/src/llama-arch.cpp +102 -16
  87. package/src/llama.cpp/src/llama-arch.h +18 -0
  88. package/src/llama.cpp/src/llama-batch.h +2 -2
  89. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  90. package/src/llama.cpp/src/llama-context.h +214 -77
  91. package/src/llama.cpp/src/llama-cparams.h +1 -0
  92. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  93. package/src/llama.cpp/src/llama-graph.h +574 -0
  94. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  95. package/src/llama.cpp/src/llama-hparams.h +9 -0
  96. package/src/llama.cpp/src/llama-io.cpp +15 -0
  97. package/src/llama.cpp/src/llama-io.h +35 -0
  98. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  99. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  100. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  101. package/src/llama.cpp/src/llama-memory.h +21 -0
  102. package/src/llama.cpp/src/llama-model.cpp +8207 -163
  103. package/src/llama.cpp/src/llama-model.h +34 -1
  104. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  105. package/src/llama.cpp/src/llama.cpp +51 -9984
  106. package/src/llama.cpp/tests/test-backend-ops.cpp +88 -9
  107. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  108. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -297,8 +297,27 @@ static int ggml_backend_opencl_n_devices = 0;
297
297
  struct ProfilingInfo {
298
298
  std::string op_name;
299
299
  std::string kernel_name;
300
- // Kernel execution time in nanoseconds.
301
- cl_ulong duration_ns;
300
+
301
+ cl_kernel kernel;
302
+ cl_event evt;
303
+
304
+ cl_ulong cmd_queued;
305
+ cl_ulong cmd_submit;
306
+ cl_ulong cmd_start;
307
+ cl_ulong cmd_end;
308
+ cl_ulong overhead_start;
309
+ cl_ulong overhead_end;
310
+ // For the times below, see spec for clGetEventProfilingInfo
311
+ // The time kernel spent in cmd queue - SUBMIT - QUEUED
312
+ cl_ulong cmd_queued_duration_ns;
313
+ // The time kernel spent for submission - START - SUBMIT
314
+ cl_ulong cmd_submit_duration_ns;
315
+ // Kernel execution time in nanoseconds - END - START
316
+ cl_ulong cmd_duration_ns;
317
+ // The time for the kernel to complete - COMPLETE - END
318
+ cl_ulong cmd_complete_duration_ns;
319
+ // Total time to finish the kernel - COMPLETE - QUEUED
320
+ cl_ulong cmd_total_duration_ns;
302
321
  // Global and local work sizes.
303
322
  size_t global_size[3];
304
323
  size_t local_size[3];
@@ -903,12 +922,56 @@ static void ggml_cl2_free(void) {
903
922
  return;
904
923
  }
905
924
 
925
+ // Populate profiling info
926
+ for (ProfilingInfo & info : g_profiling_info) {
927
+ cl_ulong cmd_queued;
928
+ cl_ulong cmd_submit;
929
+ cl_ulong cmd_start;
930
+ cl_ulong cmd_end;
931
+ cl_ulong cmd_complete;
932
+
933
+ CL_CHECK(clWaitForEvents(1, &info.evt));
934
+ CL_CHECK(clGetEventProfilingInfo(
935
+ info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL));
936
+ CL_CHECK(clGetEventProfilingInfo(
937
+ info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL));
938
+ CL_CHECK(clGetEventProfilingInfo(
939
+ info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL));
940
+ CL_CHECK(clGetEventProfilingInfo(
941
+ info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL));
942
+ CL_CHECK(clGetEventProfilingInfo(
943
+ info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL));
944
+ CL_CHECK(clReleaseEvent(info.evt));
945
+
946
+ char kernel_name[512];
947
+ CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME,
948
+ sizeof(kernel_name), kernel_name, NULL));
949
+ info.kernel_name = kernel_name;
950
+
951
+ info.cmd_queued = cmd_queued;
952
+ info.cmd_submit = cmd_submit;
953
+ info.cmd_start = cmd_start;
954
+ info.cmd_end = cmd_end;
955
+
956
+ info.cmd_queued_duration_ns = cmd_submit - cmd_queued;
957
+ info.cmd_submit_duration_ns = cmd_start - cmd_submit;
958
+ info.cmd_duration_ns = cmd_end - cmd_start;
959
+ info.cmd_complete_duration_ns = cmd_complete - cmd_end;
960
+ info.cmd_total_duration_ns = cmd_complete - cmd_queued;
961
+ }
962
+
963
+ // Dump a csv
906
964
  float total_kernel_time = 0;
907
- fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n");
965
+ fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
908
966
  for (const ProfilingInfo & info : g_profiling_info) {
909
- total_kernel_time += info.duration_ns/1.e6f;
910
- fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
911
- info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
967
+ total_kernel_time += info.cmd_duration_ns/1.e6f;
968
+ fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
969
+ info.op_name.c_str(), info.kernel_name.c_str(),
970
+ info.cmd_queued_duration_ns/1.e6f,
971
+ info.cmd_submit_duration_ns/1.e6f,
972
+ info.cmd_duration_ns/1.e6f,
973
+ info.cmd_complete_duration_ns/1.e6f,
974
+ info.cmd_total_duration_ns/1.e6f,
912
975
  info.global_size[0], info.global_size[1], info.global_size[2],
913
976
  info.local_size[0], info.local_size[2], info.local_size[2],
914
977
  info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
@@ -916,6 +979,27 @@ static void ggml_cl2_free(void) {
916
979
  fclose(fperf);
917
980
 
918
981
  GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
982
+
983
+ // Dump a simple chrome trace
984
+ FILE* ftrace = fopen("cl_trace.json", "w");
985
+ if (!ftrace) {
986
+ GGML_LOG_ERROR("Failed to open cl_trace.json\n");
987
+ return;
988
+ }
989
+
990
+ fprintf(ftrace, "[\n");
991
+ for (const ProfilingInfo & info : g_profiling_info) {
992
+ fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
993
+ info.kernel_name.c_str(), info.cmd_queued/1000);
994
+ fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
995
+ info.kernel_name.c_str(), info.cmd_submit/1000);
996
+
997
+ fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
998
+ info.kernel_name.c_str(), info.cmd_start/1000);
999
+ fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
1000
+ info.kernel_name.c_str(), info.cmd_end/1000);
1001
+ }
1002
+ fclose(ftrace);
919
1003
  #endif
920
1004
  }
921
1005
 
@@ -2062,25 +2146,14 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
2062
2146
  // Profiling utility
2063
2147
  //------------------------------------------------------------------------------
2064
2148
  #ifdef GGML_OPENCL_PROFILING
2065
- void populateProfilingInfo(
2149
+ static void populateProfilingInfo(
2066
2150
  ProfilingInfo& info, cl_event evt, cl_kernel kernel,
2067
2151
  size_t global_size[3], size_t local_size[3],
2068
2152
  const ggml_tensor * tensor) {
2069
- cl_ulong start;
2070
- cl_ulong end;
2071
- CL_CHECK(clWaitForEvents(1, &evt));
2072
- CL_CHECK(clGetEventProfilingInfo(
2073
- evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
2074
- CL_CHECK(clGetEventProfilingInfo(
2075
- evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
2076
-
2077
- char kernel_name[512];
2078
- CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
2079
- sizeof(kernel_name), kernel_name, NULL));
2080
-
2081
- info.duration_ns = end - start;
2082
- info.op_name = tensor->name;
2083
- info.kernel_name = kernel_name;
2153
+ info.op_name = tensor->name;
2154
+ info.kernel = kernel;
2155
+ info.evt = evt;
2156
+
2084
2157
  info.local_size[0] = local_size[0];
2085
2158
  info.local_size[1] = local_size[1];
2086
2159
  info.local_size[2] = local_size[2];
@@ -66,6 +66,9 @@ if (WIN32)
66
66
  find_package(MKL REQUIRED)
67
67
  target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
68
68
  else()
69
+ if (GGML_SYCL_GRAPH)
70
+ add_compile_definitions(GGML_SYCL_GRAPH)
71
+ endif()
69
72
  if (GGML_SYCL_TARGET STREQUAL "INTEL")
70
73
  target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
71
74
  elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
@@ -26,7 +26,7 @@
26
26
  #include "softmax.hpp"
27
27
  #include "tsembd.hpp"
28
28
  #include "im2col.hpp"
29
- #include "wkv6.hpp"
29
+ #include "wkv.hpp"
30
30
  #include "outprod.hpp"
31
31
  #include "element_wise.hpp"
32
32
  #include "cpy.hpp"
@@ -301,6 +301,7 @@ inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
301
301
  return opt;
302
302
  }
303
303
 
304
+ namespace sycl_ex = sycl::ext::oneapi::experimental;
304
305
  struct ggml_backend_sycl_context {
305
306
  int device;
306
307
  std::string name;
@@ -392,6 +393,10 @@ struct ggml_backend_sycl_context {
392
393
  return pool(device);
393
394
  }
394
395
 
396
+ #ifdef GGML_SYCL_GRAPH
397
+ std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr;
398
+ #endif
399
+
395
400
  ggml_sycl_pool & host_pool(int device) {
396
401
  if (host_pools[device] == nullptr) {
397
402
  host_pools[device] = new_pool_for_host(stream(device, 0), device);
@@ -474,6 +479,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
474
479
  int ne0, int ne1, int ne2, int ne3,
475
480
  int ne10, int ne11, int ne12, int ne13,
476
481
  /*int s0, */ int s1, int s2, int s3,
482
+ /*int s00,*/ int s01, int s02, int s03,
477
483
  /*int s10,*/ int s11, int s12, int s13,
478
484
  const sycl::nd_item<3> &item_ct1) {
479
485
  const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -495,9 +501,9 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
495
501
  const int i12 = i2 % ne12;
496
502
  const int i13 = i3 % ne13;
497
503
 
498
- const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
504
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
499
505
  const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
500
- const size_t i_dst = i_src0;
506
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
501
507
 
502
508
  const src0_t * src0_row = src0 + i_src0;
503
509
  const src1_t * src1_row = src1 + i_src1;
@@ -515,6 +521,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
515
521
  int ne0, int ne1, int ne2, int ne3,
516
522
  int ne10, int ne11, int ne12, int ne13,
517
523
  /*int s0, */ int s1, int s2, int s3,
524
+ /*int s00,*/ int s01, int s02, int s03,
518
525
  /*int s10,*/ int s11, int s12, int s13,
519
526
  const sycl::nd_item<3> &item_ct1) {
520
527
 
@@ -534,9 +541,9 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
534
541
  const int i12 = i2 % ne12;
535
542
  const int i13 = i3 % ne13;
536
543
 
537
- const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
544
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
538
545
  const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
539
- const size_t i_dst = i_src0;
546
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
540
547
 
541
548
  const src0_t * src0_row = src0 + i_src0;
542
549
  const src1_t * src1_row = src1 + i_src1;
@@ -566,9 +573,11 @@ struct bin_bcast_sycl {
566
573
  int nr[4] = { nr0, nr1, nr2, nr3 };
567
574
 
568
575
  // collapse dimensions until first broadcast dimension
569
- int64_t cne0[] = {ne0, ne1, ne2, ne3};
576
+ int64_t cne[] = {ne0, ne1, ne2, ne3};
577
+ int64_t cne0[] = {ne00, ne01, ne02, ne03};
570
578
  int64_t cne1[] = {ne10, ne11, ne12, ne13};
571
- size_t cnb0[] = {nb0, nb1, nb2, nb3};
579
+ size_t cnb[] = {nb0, nb1, nb2, nb3};
580
+ size_t cnb0[] = {nb00, nb01, nb02, nb03};
572
581
  size_t cnb1[] = {nb10, nb11, nb12, nb13};
573
582
  auto collapse = [](int64_t cne[]) {
574
583
  cne[0] *= cne[1];
@@ -583,32 +592,41 @@ struct bin_bcast_sycl {
583
592
  cnb[3] *= cne[3];
584
593
  };
585
594
 
586
- for (int i = 0; i < 4; i++) {
587
- if (nr[i] != 1) {
588
- break;
589
- }
590
- if (i > 0) {
591
- collapse_nb(cnb0, cne0);
592
- collapse_nb(cnb1, cne1);
593
- collapse(cne0);
594
- collapse(cne1);
595
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
596
+ for (int i = 0; i < 4; i++) {
597
+ if (nr[i] != 1) {
598
+ break;
599
+ }
600
+ if (i > 0) {
601
+ collapse_nb(cnb, cne);
602
+ collapse_nb(cnb0, cne0);
603
+ collapse_nb(cnb1, cne1);
604
+ collapse(cne);
605
+ collapse(cne0);
606
+ collapse(cne1);
607
+ }
595
608
  }
596
609
  }
597
610
  {
598
- int64_t ne0 = cne0[0];
599
- int64_t ne1 = cne0[1];
600
- int64_t ne2 = cne0[2];
601
- int64_t ne3 = cne0[3];
611
+ int64_t ne0 = cne[0];
612
+ int64_t ne1 = cne[1];
613
+ int64_t ne2 = cne[2];
614
+ int64_t ne3 = cne[3];
602
615
 
603
616
  int64_t ne10 = cne1[0];
604
617
  int64_t ne11 = cne1[1];
605
618
  int64_t ne12 = cne1[2];
606
619
  int64_t ne13 = cne1[3];
607
620
 
608
- size_t nb0 = cnb0[0];
609
- size_t nb1 = cnb0[1];
610
- size_t nb2 = cnb0[2];
611
- size_t nb3 = cnb0[3];
621
+ size_t nb0 = cnb[0];
622
+ size_t nb1 = cnb[1];
623
+ size_t nb2 = cnb[2];
624
+ size_t nb3 = cnb[3];
625
+
626
+ size_t nb00 = cnb0[0];
627
+ size_t nb01 = cnb0[1];
628
+ size_t nb02 = cnb0[2];
629
+ size_t nb03 = cnb0[3];
612
630
 
613
631
  size_t nb10 = cnb1[0];
614
632
  size_t nb11 = cnb1[1];
@@ -625,6 +643,28 @@ struct bin_bcast_sycl {
625
643
  size_t s12 = nb12 / sizeof(src1_t);
626
644
  size_t s13 = nb13 / sizeof(src1_t);
627
645
 
646
+ size_t s00 = nb00 / sizeof(src0_t);
647
+ size_t s01 = nb01 / sizeof(src0_t);
648
+ size_t s02 = nb02 / sizeof(src0_t);
649
+ size_t s03 = nb03 / sizeof(src0_t);
650
+
651
+ GGML_UNUSED(s00);
652
+
653
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
654
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
655
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
656
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
657
+
658
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
659
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
660
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
661
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
662
+
663
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
664
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
665
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
666
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
667
+
628
668
  GGML_ASSERT(s0 == 1);
629
669
  GGML_ASSERT(s10 == 1);
630
670
 
@@ -661,8 +701,8 @@ struct bin_bcast_sycl {
661
701
  [=](sycl::nd_item<3> item_ct1) {
662
702
  k_bin_bcast_unravel<bin_op>(
663
703
  src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
664
- ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
665
- s13, item_ct1);
704
+ ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
705
+ s03, s11, s12, s13, item_ct1);
666
706
  });
667
707
  }
668
708
  } else {
@@ -680,7 +720,7 @@ struct bin_bcast_sycl {
680
720
  [=](sycl::nd_item<3> item_ct1) {
681
721
  k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
682
722
  ne2, ne3, ne10, ne11, ne12, ne13,
683
- s1, s2, s3, s11, s12, s13,
723
+ s1, s2, s3, s01, s02, s03, s11, s12, s13,
684
724
  item_ct1);
685
725
  });
686
726
  }
@@ -138,7 +138,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
138
138
  stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
139
139
  sycl::range<3>(1, 1, WARP_SIZE),
140
140
  sycl::range<3>(1, 1, WARP_SIZE)),
141
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]]{
141
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
142
142
  dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
143
143
  });
144
144
 
@@ -210,7 +210,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
210
210
 
211
211
  stream->parallel_for(
212
212
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
213
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
213
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
214
214
  dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
215
215
  nrows, item_ct1);
216
216
  });
@@ -879,7 +879,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
879
879
 
880
880
  stream->parallel_for(
881
881
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
882
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
882
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
883
883
  dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
884
884
  vx, y, dst, ncols, nrows, item_ct1);
885
885
  });
@@ -902,7 +902,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
902
902
 
903
903
  stream->parallel_for(
904
904
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
905
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
905
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
906
906
  dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
907
907
  vx, y, dst, ncols, nrows, item_ct1);
908
908
  });
@@ -923,7 +923,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
923
923
 
924
924
  stream->parallel_for(
925
925
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
926
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
926
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
927
927
  dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
928
928
  vx, y, dst, ncols, nrows, item_ct1);
929
929
  });
@@ -944,7 +944,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
944
944
 
945
945
  stream->parallel_for(
946
946
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
947
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
947
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
948
948
  dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
949
949
  vx, y, dst, ncols, nrows, item_ct1);
950
950
  });
@@ -965,7 +965,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
965
965
 
966
966
  stream->parallel_for(
967
967
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
968
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
968
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
969
969
  dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
970
970
  vx, y, dst, ncols, nrows, item_ct1);
971
971
  });
@@ -986,7 +986,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
986
986
 
987
987
  stream->parallel_for(
988
988
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
989
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
989
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
990
990
  dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
991
991
  vx, y, dst, ncols, nrows, item_ct1);
992
992
  });
@@ -1004,7 +1004,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
1004
1004
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1005
1005
  stream->parallel_for(
1006
1006
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1007
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1007
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1008
1008
  dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
1009
1009
  });
1010
1010
  }
@@ -1020,7 +1020,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
1020
1020
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1021
1021
  stream->parallel_for(
1022
1022
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1023
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1023
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1024
1024
  dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
1025
1025
  });
1026
1026
  }
@@ -1036,7 +1036,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
1036
1036
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1037
1037
  stream->parallel_for(
1038
1038
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1039
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1039
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1040
1040
  dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
1041
1041
  });
1042
1042
  }
@@ -1049,7 +1049,7 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
1049
1049
  const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
1050
1050
  stream->parallel_for(
1051
1051
  sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
1052
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1052
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1053
1053
  dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
1054
1054
  });
1055
1055
  }
@@ -1065,7 +1065,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
1065
1065
  const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
1066
1066
  stream->parallel_for(
1067
1067
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
1068
- [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
1068
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
1069
1069
  dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
1070
1070
  });
1071
1071
  }
@@ -1143,7 +1143,6 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
1143
1143
  default:
1144
1144
  printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
1145
1145
  GGML_ABORT("fatal error");
1146
- break;
1147
1146
  }
1148
1147
 
1149
1148
  GGML_UNUSED(src1);