@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0

package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt
@@ -0,0 +1,81 @@
+ if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
+ message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
+ endif()
+
+ check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
+
+ if (DEFINED ENV{ONEAPI_ROOT})
+ message(STATUS "Using oneAPI Release SYCL compiler (icpx).")
+ elseif(SUPPORTS_SYCL)
+ message(WARNING "Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}.
+ If you expected the oneAPI Release compiler, please install oneAPI & source it, like:
+ source /opt/intel/oneapi/setvars.sh")
+ else()
+ message(FATAL_ERROR, "C++ compiler lacks SYCL support.")
+ endif()
+ message(STATUS "SYCL found")
+ #todo: AOT
+
+ add_library(ggml-sycl
+ ggml-sycl.cpp
+ ../../include/ggml-sycl.h)
+
+ target_link_libraries(ggml-sycl PRIVATE ggml-base)
+ target_include_directories(ggml-sycl PRIVATE . ..)
+
+ if (GGML_SYCL_F16)
+ if (GGML_SYCL_TARGET STREQUAL "AMD")
+ message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
+ endif()
+ add_compile_definitions(GGML_SYCL_F16)
+ endif()
+
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
+
+ if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+ add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+ # INFO: Allowed Sub_group_sizes are not consistent through all
+ # hip targets. For example, 64 is used for certain models, but the backend
+ # does not support it.
+ # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
+ add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+ else()
+ add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
+ endif()
+
+ file(GLOB GGML_HEADERS_SYCL "*.hpp")
+ file(GLOB GGML_SOURCES_SYCL "*.cpp")
+ target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
+
+ find_package(DNNL)
+ message("-- DNNL found:" ${DNNL_FOUND})
+
+ if (GGML_SYCL_TARGET STREQUAL "INTEL")
+ add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+ else()
+ add_compile_definitions(GGML_SYCL_DNNL=0)
+ endif()
+
+ if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+ target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+ endif()
+
+ if (WIN32)
+ find_package(IntelSYCL REQUIRED)
+ find_package(MKL REQUIRED)
+ target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+ else()
+ if (GGML_SYCL_TARGET STREQUAL "INTEL")
+ target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+ elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
+ target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+ if (GGML_SYCL_HIP_TARGET STREQUAL "")
+ message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_HIP_TARGET has not been set.")
+ endif()
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${GGML_SYCL_HIP_TARGET}")
+ target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+ endif()
+ endif()
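
The build script above hard-codes the sub-group ("warp") width per target: GGML_SYCL_WARP_SIZE=32 for NVIDIA and AMD, 16 on the default (Intel) path. A minimal, hypothetical sketch of how a translation unit can consume that compile definition; the fallback value only mirrors the else() branch above and is not taken from this diff:

```cpp
// Hypothetical consumer of the GGML_SYCL_WARP_SIZE definition set by the
// CMake script above; the fallback mirrors the default (Intel) branch.
#include <cstdio>

#ifndef GGML_SYCL_WARP_SIZE
#define GGML_SYCL_WARP_SIZE 16
#endif

int main() {
    // Kernels would typically size their sub-group reductions around this constant.
    std::printf("compiled sub-group size: %d\n", GGML_SYCL_WARP_SIZE);
    return 0;
}
```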

package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp
@@ -26,5 +26,8 @@
  #include "softmax.hpp"
  #include "tsembd.hpp"
  #include "im2col.hpp"
+ #include "wkv6.hpp"
+ #include "outprod.hpp"
+ #include "element_wise.hpp"

  #endif // GGML_SYCL_BACKEND_HPP

package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp
@@ -62,3 +62,43 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
  }
  return sycl_down_blk_size;
  }
+
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const ggml_sycl_op_flatten_t op) try {
+ const int64_t nrows0 = ggml_nrows(src0);
+
+ const bool use_src1 = src1 != nullptr;
+ const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+ GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+ GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+ ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+ ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+ // dd = data device
+ float * src0_ddf = (float *) src0->data;
+ float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+ float * dst_ddf = (float *) dst->data;
+
+ ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+ ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+ ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
+
+ ggml_sycl_set_device(ctx.device);
+ queue_ptr main_stream = ctx.stream();
+ // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+ // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+ // do the computation
+ op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+ // print_ggml_tensor("tensor", dst);
+ }
+ catch (sycl::exception const &exc) {
+
+ std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+ << ", line:" << __LINE__ << std::endl;
+ std::exit(1);
+ }
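
ggml_sycl_op_flatten above is a thin dispatcher: it extracts raw float pointers ("dd" = data device) from the tensors and forwards them, together with the stream, to a callback of type ggml_sycl_op_flatten_t. A host-only sketch of the same callback pattern, with hypothetical names and none of the SYCL or ggml plumbing:

```cpp
// Minimal sketch of the dispatch pattern used by ggml_sycl_op_flatten,
// assuming plain host tensors (names here are illustrative, not from ggml).
#include <cstddef>
#include <cstdio>
#include <vector>

struct tensor { std::vector<float> data; };

// Mirrors ggml_sycl_op_flatten_t, minus the context/stream arguments.
typedef void (*flatten_op_t)(const float * src0, const float * src1,
                             float * dst, size_t n);

static void op_add(const float * a, const float * b, float * d, size_t n) {
    for (size_t i = 0; i < n; ++i) d[i] = a[i] + (b ? b[i] : 0.0f);
}

static void flatten(const tensor & src0, const tensor * src1,
                    tensor & dst, flatten_op_t op) {
    // Extract raw data pointers, then hand off to the element-wise op.
    op(src0.data.data(), src1 ? src1->data.data() : nullptr,
       dst.data.data(), dst.data.size());
}

int main() {
    tensor a{{1, 2, 3}}, b{{4, 5, 6}}, d{{0, 0, 0}};
    flatten(a, &b, d, op_add);
    std::printf("%g %g %g\n", d.data[0], d.data[1], d.data[2]); // 5 7 9
    return 0;
}
```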

package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp
@@ -404,4 +404,262 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {

  int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);

+ typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1,
+ ggml_tensor *dst, const float *src0_dd,
+ const float *src1_dd, float *dst_dd,
+ const queue_ptr &main_stream);
+
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s10,*/ int s11, int s12, int s13,
+ const sycl::nd_item<3> &item_ct1) {
+ const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+ const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+ item_ct1.get_local_id(1));
+ const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+ item_ct1.get_local_id(0)) /
+ ne3;
+ const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+ item_ct1.get_local_id(0)) %
+ ne3;
+
+ if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ return;
+ }
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i_src0;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ for (int i0 = i0s; i0 < ne0;
+ i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+ }
+ }
+
+ template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+ int ne0, int ne1, int ne2, int ne3,
+ int ne10, int ne11, int ne12, int ne13,
+ /*int s0, */ int s1, int s2, int s3,
+ /*int s10,*/ int s11, int s12, int s13,
+ const sycl::nd_item<3> &item_ct1) {
+
+ const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+ item_ct1.get_local_id(2);
+
+ const int i3 = i/(ne2*ne1*ne0);
+ const int i2 = (i/(ne1*ne0)) % ne2;
+ const int i1 = (i/ne0) % ne1;
+ const int i0 = i % ne0;
+
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ return;
+ }
+
+ const int i11 = i1 % ne11;
+ const int i12 = i2 % ne12;
+ const int i13 = i3 % ne13;
+
+ const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i_src0;
+
+ const src0_t * src0_row = src0 + i_src0;
+ const src1_t * src1_row = src1 + i_src1;
+ dst_t * dst_row = dst + i_dst;
+
+ const int i10 = i0 % ne10;
+ dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+ }
+
+
+ template<float (*bin_op)(const float, const float)>
+ struct bin_bcast_sycl {
+ template <typename src0_t, typename src1_t, typename dst_t>
+ void operator()(ggml_backend_sycl_context & ctx,
+ const struct ggml_tensor *src0,
+ const struct ggml_tensor *src1, struct ggml_tensor *dst,
+ const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+ queue_ptr stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int nr0 = ne10/ne0;
+ int nr1 = ne11/ne1;
+ int nr2 = ne12/ne2;
+ int nr3 = ne13/ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ // collapse dimensions until first broadcast dimension
+ int64_t cne0[] = {ne0, ne1, ne2, ne3};
+ int64_t cne1[] = {ne10, ne11, ne12, ne13};
+ size_t cnb0[] = {nb0, nb1, nb2, nb3};
+ size_t cnb1[] = {nb10, nb11, nb12, nb13};
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ {
+ int64_t ne0 = cne0[0];
+ int64_t ne1 = cne0[1];
+ int64_t ne2 = cne0[2];
+ int64_t ne3 = cne0[3];
+
+ int64_t ne10 = cne1[0];
+ int64_t ne11 = cne1[1];
+ int64_t ne12 = cne1[2];
+ int64_t ne13 = cne1[3];
+
+ size_t nb0 = cnb0[0];
+ size_t nb1 = cnb0[1];
+ size_t nb2 = cnb0[2];
+ size_t nb3 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s10 == 1);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+ sycl::range<3> block_dims(1, 1, 1);
+ block_dims[2] = std::min<unsigned int>(hne0, block_size);
+ block_dims[1] = std::min<unsigned int>(
+ ne1, block_size / (unsigned int)block_dims[2]);
+ block_dims[0] = std::min(
+ std::min<unsigned int>(
+ ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+ (unsigned int)block_dims[1]),
+ 64U);
+
+ sycl::range<3> block_nums(
+ (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+ (ne1 + block_dims[1] - 1) / block_dims[1],
+ (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+ if (block_nums[0] > 65535) {
+ // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+ int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+ {
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+ sycl::range<3>(1, 1, block_size),
+ sycl::range<3>(1, 1, block_size)),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_bin_bcast_unravel<bin_op>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+ ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+ s13, item_ct1);
+ });
+ }
+ } else {
+ /*
+ DPCT1049:16: The work-group size passed to the SYCL kernel may
+ exceed the limit. To get the device limit, query
+ info::device::max_work_group_size. Adjust the work-group size if
+ needed.
+ */
+ dpct::has_capability_or_fail(stream->get_device(),
+ {sycl::aspect::fp16});
+
+ stream->parallel_for(
+ sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) {
+ k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+ ne2, ne3, ne10, ne11, ne12, ne13,
+ s1, s2, s3, s11, s12, s13,
+ item_ct1);
+ });
+ }
+ }
+ }
+ };
+
+ template <class op>
+ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const float *src0_dd, const float *src1_dd,
+ float *dst_dd,
+ const queue_ptr &main_stream) {
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+ (sycl::half *)dst_dd, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+ main_stream);
+ } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+ op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+ main_stream);
+ } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+ op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+ main_stream);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+ }
+
+
+ void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+ const ggml_tensor *src1, ggml_tensor *dst,
+ const ggml_sycl_op_flatten_t op);
+
  #endif // GGML_SYCL_COMMON_HPP
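
Both k_bin_bcast kernels above resolve broadcasting the same way: an output coordinate (i1, i2, i3) maps into src1 by taking it modulo src1's extent in that dimension, so size-1 dimensions repeat. A plain-C++ host sketch of that indexing for two dimensions, with illustrative values not taken from the diff:

```cpp
// Host-side sketch of the modulo-based broadcast indexing used by the
// k_bin_bcast kernels above (illustrative shapes, addition as bin_op).
#include <cstdio>

int main() {
    const int ne0 = 4, ne1 = 3;    // dst/src0 extents
    const int ne10 = 4, ne11 = 1;  // src1 broadcast along dim 1
    float src0[ne1][ne0] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
    float src1[ne11][ne10] = {{10, 20, 30, 40}};
    float dst[ne1][ne0];

    for (int i1 = 0; i1 < ne1; ++i1) {
        const int i11 = i1 % ne11;  // same trick as the kernels
        for (int i0 = 0; i0 < ne0; ++i0) {
            const int i10 = i0 % ne10;
            dst[i1][i0] = src0[i1][i0] + src1[i11][i10];
        }
    }
    std::printf("dst[2][3] = %g\n", dst[2][3]); // 12 + 40 = 52
    return 0;
}
```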

package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
  concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
  });
  break;
+ // dim >=2 will be dispatched to the default path
  default:
  stream->parallel_for(
  sycl::nd_range<3>(gridDim *

package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -15,6 +15,7 @@

  #include <sycl/sycl.hpp>
  #include <sycl/half_type.hpp>
+ #include <syclcompat/math.hpp>
  #include <oneapi/mkl.hpp>
  #include <map>

@@ -1830,31 +1831,10 @@ namespace dpct
  : id);
  }

- template <typename T>
- sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val)
- {
- return sycl::vec<T, 1>(val)
- .template as<sycl::vec<
- std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>, 4>>()
- .template convert<T>();
- }
-
- template <typename T1, typename T2>
- using dot_product_acc_t =
- std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
- uint32_t, int32_t>;
-
  template <typename T1, typename T2, typename T3>
  inline auto dp4a(T1 a, T2 b, T3 c)
  {
- dot_product_acc_t<T1, T2> res = c;
- auto va = extract_and_sign_or_zero_extend4(a);
- auto vb = extract_and_sign_or_zero_extend4(b);
- res += va[0] * vb[0];
- res += va[1] * vb[1];
- res += va[2] * vb[2];
- res += va[3] * vb[3];
+ return syclcompat::dp4a(a, b, c);
  }

  struct sub_sat
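
The deleted helper computed dp4a in software: split each 32-bit operand into four 8-bit lanes (sign- or zero-extended depending on signedness), multiply lane-wise, and accumulate into c; syclcompat::dp4a is the drop-in replacement. A plain C++ sketch of the signed case, assuming int32_t operands:

```cpp
// Reference semantics of dp4a for signed 8-bit lanes, mirroring the manual
// implementation removed above (the unsigned case zero-extends instead).
#include <cstdint>
#include <cstdio>

static int32_t dp4a_ref(int32_t a, int32_t b, int32_t c) {
    for (int i = 0; i < 4; ++i) {
        const int8_t va = (int8_t)(a >> (8 * i)); // lane i of a, sign-extended
        const int8_t vb = (int8_t)(b >> (8 * i)); // lane i of b, sign-extended
        c += (int32_t)va * (int32_t)vb;
    }
    return c;
}

int main() {
    // lanes of a: 1,2,3,4; lanes of b: 5,6,7,8 -> 1*5 + 2*6 + 3*7 + 4*8 = 70
    std::printf("%d\n", dp4a_ref(0x04030201, 0x08070605, 0)); // prints 70
    return 0;
}
```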