@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0

package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt
@@ -0,0 +1,81 @@
+if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
+    message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
+endif()
+
+check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
+
+if (DEFINED ENV{ONEAPI_ROOT})
+    message(STATUS "Using oneAPI Release SYCL compiler (icpx).")
+elseif(SUPPORTS_SYCL)
+    message(WARNING "Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}.
+        If you expected the oneAPI Release compiler, please install oneAPI & source it, like:
+        source /opt/intel/oneapi/setvars.sh")
+else()
+    message(FATAL_ERROR, "C++ compiler lacks SYCL support.")
+endif()
+message(STATUS "SYCL found")
+#todo: AOT
+
+add_library(ggml-sycl
+            ggml-sycl.cpp
+            ../../include/ggml-sycl.h)
+
+target_link_libraries(ggml-sycl PRIVATE ggml-base)
+target_include_directories(ggml-sycl PRIVATE . ..)
+
+if (GGML_SYCL_F16)
+    if (GGML_SYCL_TARGET STREQUAL "AMD")
+        message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
+    endif()
+    add_compile_definitions(GGML_SYCL_F16)
+endif()
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
+
+if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+    # INFO: Allowed Sub_group_sizes are not consistent through all
+    # hip targets. For example, 64 is used for certain models, but the backend
+    # does not support it.
+    # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
+    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+else()
+    add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
+endif()
+
+file(GLOB   GGML_HEADERS_SYCL "*.hpp")
+file(GLOB   GGML_SOURCES_SYCL "*.cpp")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
+
+find_package(DNNL)
+message("-- DNNL found:" ${DNNL_FOUND})
+
+if (GGML_SYCL_TARGET STREQUAL "INTEL")
+    add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+else()
+    add_compile_definitions(GGML_SYCL_DNNL=0)
+endif()
+
+if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+    target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+endif()
+
+if (WIN32)
+    find_package(IntelSYCL REQUIRED)
+    find_package(MKL REQUIRED)
+    target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+else()
+    if (GGML_SYCL_TARGET STREQUAL "INTEL")
+        target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
+    elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
+        target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+    elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+        if (GGML_SYCL_HIP_TARGET STREQUAL "")
+            message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_HIP_TARGET has not been set.")
+        endif()
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${GGML_SYCL_HIP_TARGET}")
+        target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+    endif()
+endif()

package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp
@@ -15,6 +15,7 @@
 
 #include "concat.hpp"
 #include "common.hpp"
+#include "conv.hpp"
 #include "convert.hpp"
 #include "dequantize.hpp"
 #include "dmmv.hpp"
@@ -23,5 +24,10 @@
 #include "rope.hpp"
 #include "norm.hpp"
 #include "softmax.hpp"
+#include "tsembd.hpp"
+#include "im2col.hpp"
+#include "wkv6.hpp"
+#include "outprod.hpp"
+#include "element_wise.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP

package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp
@@ -51,3 +51,54 @@ void ggml_sycl_host_free(void* ptr) try {
             << ", line:" << __LINE__ << std::endl;
   std::exit(1);
 }
+
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
+    const int64_t max_range = std::numeric_limits<int>::max();
+    int64_t sycl_down_blk_size = block_size;
+    int64_t global_range = accumulate_block_num * sycl_down_blk_size;
+    while(global_range > max_range) {
+        sycl_down_blk_size /= 2;
+        global_range = accumulate_block_num * sycl_down_blk_size;
+    }
+    return sycl_down_blk_size;
+}
+
+void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                          const ggml_tensor *src1, ggml_tensor *dst,
+                          const ggml_sycl_op_flatten_t op) try {
+    const int64_t nrows0 = ggml_nrows(src0);
+
+    const bool use_src1 = src1 != nullptr;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(              dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *)  dst->extra;
+
+    // dd = data device
+    float * src0_ddf = (float *) src0->data;
+    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+    float *  dst_ddf = (float *)  dst->data;
+
+    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+    ggml_sycl_pool_alloc<float>  dst_f(ctx.pool());
+
+    ggml_sycl_set_device(ctx.device);
+    queue_ptr main_stream = ctx.stream();
+    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+    //     ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+    // do the computation
+    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    // print_ggml_tensor("tensor", dst);
+}
+catch (sycl::exception const &exc) {
+
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}

package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp
@@ -19,6 +19,10 @@
 #include "dpct/helper.hpp"
 #include "ggml-sycl.h"
 #include "presets.hpp"
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+#endif
 
 #define GGML_COMMON_DECL_SYCL
 #define GGML_COMMON_IMPL_SYCL
@@ -276,6 +280,52 @@ struct ggml_backend_sycl_context {
         return stream(device, 0);
     }
 
+#if GGML_SYCL_DNNL
+    dnnl::engine make_engine(sycl::queue* q) {
+        // Get the device associated with the queue
+        sycl::device dev = q->get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q->get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        return eng;
+    }
+
+    std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+    std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+    dnnl::stream stream_dnnl(int device, int _stream) {
+        auto q = stream(device, _stream);
+        return stream_dnnl(q);
+    }
+    dnnl::engine engine_dnnl(sycl::queue* qptr) {
+        auto it = engine_map.find(qptr);
+        if (it == engine_map.end()) {
+            auto eng = make_engine(qptr);
+            engine_map[qptr] = eng;
+            return eng;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl(sycl::queue* qptr) {
+        auto it = stream_map.find(qptr);
+        if (it == stream_map.end()) {
+            auto eng = engine_dnnl(qptr);
+            auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+            stream_map[qptr] = stream;
+            return stream;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl() {
+        return stream_dnnl(device, 0);
+    }
+#endif
+
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
 
@@ -352,4 +402,264 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
     return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
 }
 
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
+
+typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                       const ggml_tensor *src1,
+                                       ggml_tensor *dst, const float *src0_dd,
+                                       const float *src1_dd, float *dst_dd,
+                                       const queue_ptr &main_stream);
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1, int s2, int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1));
+    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) /
+                   ne3;
+    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) %
+                   ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s3  +  i2*s2  +  i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0;
+         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1, int s2, int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s3  +  i2*s2  +  i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+    template <typename src0_t, typename src1_t, typename dst_t>
+    void operator()(ggml_backend_sycl_context & ctx,
+                    const struct ggml_tensor *src0,
+                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                    queue_ptr stream) {
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
+
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
+
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            sycl::range<3> block_dims(1, 1, 1);
+            block_dims[2] = std::min<unsigned int>(hne0, block_size);
+            block_dims[1] = std::min<unsigned int>(
+                ne1, block_size / (unsigned int)block_dims[2]);
+            block_dims[0] = std::min(
+                std::min<unsigned int>(
+                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                   (unsigned int)block_dims[1]),
+                64U);
+
+            sycl::range<3> block_nums(
+                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                (ne1 + block_dims[1] - 1) / block_dims[1],
+                (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+            if (block_nums[0] > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                {
+                    dpct::has_capability_or_fail(stream->get_device(),
+                                                 {sycl::aspect::fp16});
+
+                    stream->parallel_for(
+                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                              sycl::range<3>(1, 1, block_size),
+                                          sycl::range<3>(1, 1, block_size)),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_bin_bcast_unravel<bin_op>(
+                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                s13, item_ct1);
+                        });
+                }
+            } else {
+                /*
+                DPCT1049:16: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                dpct::has_capability_or_fail(stream->get_device(),
+                                             {sycl::aspect::fp16});
+
+                stream->parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                            ne2, ne3, ne10, ne11, ne12, ne13,
+                                            s1, s2, s3, s11, s12, s13,
+                                            item_ct1);
+                    });
+            }
+        }
+    }
+};
+
+template <class op>
+inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd,
+                                   const queue_ptr &main_stream) {
+
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+             (sycl::half *)dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+             main_stream);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+
+void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                          const ggml_tensor *src1, ggml_tensor *dst,
+                          const ggml_sycl_op_flatten_t op);
+
 #endif // GGML_SYCL_COMMON_HPP

package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
                 concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
             });
         break;
+    // dim >=2 will be dispatched to the default path
     default:
         stream->parallel_for(
             sycl::nd_range<3>(gridDim *

package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp
@@ -0,0 +1,99 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "conv.hpp"
+
+static void conv_transpose_1d_kernel(
+        const int s0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2,
+        const int src1_ne0, const int dst_ne0,
+        const float * src0, const float * src1, float * dst,
+        const sycl::nd_item<3> &item_ct1) {
+    int global_index = item_ct1.get_local_id(2) +
+                       item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    if (global_index >= output_size) {
+        return;
+    }
+
+    int out_index = global_index / dst_ne0;
+
+    float accumulator = 0;
+
+    for (int c = 0; c < src0_ne2; c++) {
+        int idx = global_index % dst_ne0;
+
+        int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+        int input_offset = src1_ne0 * c;
+
+        for (int i = 0; i < src1_ne0; i++) {
+            if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+                continue;
+            }
+            int weight_idx = idx - i*s0;
+
+            float kernel_weight = src0[kernel_offset + weight_idx];
+            float input_value = src1[input_offset+i];
+
+            accumulator += kernel_weight * input_value;
+        }
+    }
+    dst[global_index] = accumulator;
+}
+
+static void conv_transpose_1d_f32_f32_sycl(
+        const int s0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2,
+        const int src1_ne0, const int dst_ne0,
+        const float *src0, const float *src1, float *dst,
+        const queue_ptr& stream) {
+
+    const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
+    const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
+    const sycl::range<3> block_nums(1, 1, num_blocks);
+    stream->parallel_for(
+        sycl::nd_range<3>(
+            block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) {
+            conv_transpose_1d_kernel(
+                s0, output_size,
+                src0_ne0, src0_ne1, src0_ne2,
+                src1_ne0, dst_ne0,
+                src0, src1, dst, item_ct1);
+        });
+}
+
+void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst) {
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    dpct::queue_ptr stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+
+    const int s0 = opts[0];
+
+    const int64_t output_size = ggml_nelements(dst);
+
+    conv_transpose_1d_f32_f32_sycl(s0, output_size,
+        src0->ne[0], src0->ne[1], src0->ne[2],
+        src1->ne[0], dst->ne[0],
+        src0_d, src1_d, dst_d, stream);
+}
+

package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp
@@ -0,0 +1,21 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_CONV_HPP
+#define GGML_SYCL_CONV_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst);
+
+#endif // GGML_SYCL_CONV_HPP