@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
|
@@ -16,9 +16,18 @@
|
|
|
16
16
|
#include <sycl/sycl.hpp>
|
|
17
17
|
#include <sycl/half_type.hpp>
|
|
18
18
|
#include <syclcompat/math.hpp>
|
|
19
|
-
#include <oneapi/mkl.hpp>
|
|
20
19
|
#include <map>
|
|
21
20
|
|
|
21
|
+
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
|
22
|
+
#include <oneapi/mkl.hpp>
|
|
23
|
+
// Allow to use the same namespace for Intel oneMKL and oneMath
|
|
24
|
+
namespace oneapi {
|
|
25
|
+
namespace math = mkl;
|
|
26
|
+
}
|
|
27
|
+
#else
|
|
28
|
+
#include <oneapi/math.hpp>
|
|
29
|
+
#endif
|
|
30
|
+
|
|
22
31
|
#include "ggml.h"
|
|
23
32
|
|
|
24
33
|
#if defined(__linux__)
|
|
@@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
|
|
|
83
92
|
}
|
|
84
93
|
|
|
85
94
|
template <typename Ts> struct matrix_info_t {
|
|
86
|
-
oneapi::mkl::transpose transpose_info[2];
|
|
95
|
+
oneapi::math::transpose transpose_info[2];
|
|
87
96
|
Ts value_info[2];
|
|
88
97
|
std::int64_t size_info[3];
|
|
89
98
|
std::int64_t ld_info[3];
|
|
90
99
|
std::int64_t groupsize_info;
|
|
91
100
|
};
|
|
92
101
|
|
|
102
|
+
inline auto get_onemath_backend(sycl::queue& queue)
|
|
103
|
+
#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
104
|
+
-> sycl::queue&
|
|
105
|
+
#endif
|
|
106
|
+
{
|
|
107
|
+
// If the backend is known at compile-time, use oneMath backend_selector to use
|
|
108
|
+
// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
|
|
109
|
+
// fallback to runtime dispatching.
|
|
110
|
+
#if defined(GGML_SYCL_NVIDIA)
|
|
111
|
+
return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
|
|
112
|
+
#elif defined(GGML_SYCL_AMD)
|
|
113
|
+
return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
|
|
114
|
+
#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
|
|
115
|
+
return queue;
|
|
116
|
+
#else
|
|
117
|
+
static_assert(false, "Unsupported backend");
|
|
118
|
+
#endif
|
|
119
|
+
}
|
|
120
|
+
|
|
93
121
|
namespace dpct
|
|
94
122
|
{
|
|
95
123
|
typedef sycl::queue *queue_ptr;
|
|
@@ -1686,26 +1714,18 @@ namespace dpct
|
|
|
1686
1714
|
|
|
1687
1715
|
namespace detail
|
|
1688
1716
|
{
|
|
1689
|
-
|
|
1690
|
-
|
|
1691
|
-
|
|
1692
|
-
|
|
1693
|
-
|
|
1694
|
-
|
|
1695
|
-
|
|
1696
|
-
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
|
|
1700
|
-
|
|
1701
|
-
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
|
|
1702
|
-
a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
1703
|
-
beta_value, data_c, ldc);
|
|
1704
|
-
#else
|
|
1705
|
-
oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
|
|
1706
|
-
beta_value, data_c, ldc);
|
|
1707
|
-
#endif
|
|
1708
|
-
}
|
|
1717
|
+
template <class Ta, class Tb, class Tc, class Ts>
|
|
1718
|
+
inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
|
1719
|
+
int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
|
|
1720
|
+
const void * beta, void * c, int ldc) {
|
|
1721
|
+
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
1722
|
+
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
|
1723
|
+
auto data_a = get_memory<const Ta>(a);
|
|
1724
|
+
auto data_b = get_memory<const Tb>(b);
|
|
1725
|
+
auto data_c = get_memory<Tc>(c);
|
|
1726
|
+
oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
|
|
1727
|
+
lda, data_b, ldb, beta_value, data_c, ldc);
|
|
1728
|
+
}
|
|
1709
1729
|
|
|
1710
1730
|
template <typename VecT, class BinaryOperation, class = void>
|
|
1711
1731
|
class vectorized_binary
|
|
@@ -1735,7 +1755,7 @@ namespace dpct
|
|
|
1735
1755
|
};
|
|
1736
1756
|
|
|
1737
1757
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1738
|
-
inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
|
|
1758
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
|
1739
1759
|
int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
|
|
1740
1760
|
int ldb, const void * beta, void ** c, int ldc, int batch_size,
|
|
1741
1761
|
matrix_info_t<float> * matrix_info) {
|
|
@@ -1754,48 +1774,28 @@ namespace dpct
|
|
|
1754
1774
|
matrix_info->ld_info[2] = ldc;
|
|
1755
1775
|
matrix_info->groupsize_info = batch_size;
|
|
1756
1776
|
|
|
1757
|
-
|
|
1758
|
-
|
|
1759
|
-
|
|
1760
|
-
matrix_info->
|
|
1761
|
-
|
|
1762
|
-
reinterpret_cast<
|
|
1763
|
-
matrix_info->ld_info + 1,
|
|
1764
|
-
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
1765
|
-
#else
|
|
1766
|
-
sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
|
|
1767
|
-
q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
|
|
1768
|
-
matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
|
|
1769
|
-
reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
|
|
1770
|
-
matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
|
|
1771
|
-
reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
1772
|
-
#endif
|
|
1777
|
+
sycl::event e = oneapi::math::blas::column_major::gemm_batch(
|
|
1778
|
+
get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
|
|
1779
|
+
matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
|
|
1780
|
+
reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
|
|
1781
|
+
reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
|
|
1782
|
+
reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
|
|
1783
|
+
matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
|
|
1773
1784
|
}
|
|
1774
1785
|
|
|
1775
1786
|
template <class Ta, class Tb, class Tc, class Ts>
|
|
1776
|
-
inline void
|
|
1777
|
-
|
|
1778
|
-
|
|
1779
|
-
|
|
1780
|
-
long long int stride_a, const void *b, int ldb,
|
|
1781
|
-
long long int stride_b, const void *beta, void *c,
|
|
1782
|
-
int ldc, long long int stride_c, int batch_size)
|
|
1783
|
-
{
|
|
1787
|
+
inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
|
|
1788
|
+
int m, int n, int k, const void * alpha, const void * a, int lda,
|
|
1789
|
+
long long int stride_a, const void * b, int ldb, long long int stride_b,
|
|
1790
|
+
const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
|
|
1784
1791
|
Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
|
|
1785
1792
|
Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
|
|
1786
1793
|
auto data_a = get_memory<const Ta>(a);
|
|
1787
1794
|
auto data_b = get_memory<const Tb>(b);
|
|
1788
1795
|
auto data_c = get_memory<Tc>(c);
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
|
|
1793
|
-
batch_size);
|
|
1794
|
-
#else
|
|
1795
|
-
oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
|
|
1796
|
-
stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
|
|
1797
|
-
stride_c, batch_size);
|
|
1798
|
-
#endif
|
|
1796
|
+
oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
|
|
1797
|
+
data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
|
|
1798
|
+
data_c, ldc, stride_c, batch_size);
|
|
1799
1799
|
}
|
|
1800
1800
|
|
|
1801
1801
|
} // namespace detail
|
|
@@ -2259,13 +2259,10 @@ namespace dpct
|
|
|
2259
2259
|
sycl::range<3>(x, y, 1), direction);
|
|
2260
2260
|
}
|
|
2261
2261
|
|
|
2262
|
-
inline void gemm(sycl::queue &q, oneapi::
|
|
2263
|
-
|
|
2264
|
-
const void *
|
|
2265
|
-
|
|
2266
|
-
const void *beta, void *c, library_data_t c_type, int ldc,
|
|
2267
|
-
library_data_t scaling_type)
|
|
2268
|
-
{
|
|
2262
|
+
inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
|
|
2263
|
+
int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
|
|
2264
|
+
library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
2265
|
+
library_data_t scaling_type) {
|
|
2269
2266
|
if (scaling_type == library_data_t::real_float &&
|
|
2270
2267
|
c_type == library_data_t::complex_float)
|
|
2271
2268
|
{
|
|
@@ -2329,9 +2326,8 @@ namespace dpct
|
|
|
2329
2326
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2330
2327
|
library_data_t::real_float, library_data_t::real_float):
|
|
2331
2328
|
{
|
|
2332
|
-
detail::gemm_impl<oneapi::
|
|
2333
|
-
|
|
2334
|
-
ldb, beta, c, ldc);
|
|
2329
|
+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
|
2330
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
2335
2331
|
break;
|
|
2336
2332
|
}
|
|
2337
2333
|
case detail::get_type_combination_id(
|
|
@@ -2369,8 +2365,7 @@ namespace dpct
|
|
|
2369
2365
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2370
2366
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2371
2367
|
{
|
|
2372
|
-
detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
|
|
2373
|
-
oneapi::mkl::bfloat16, float>(
|
|
2368
|
+
detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
|
2374
2369
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
|
|
2375
2370
|
break;
|
|
2376
2371
|
}
|
|
@@ -2390,7 +2385,7 @@ namespace dpct
|
|
|
2390
2385
|
default:
|
|
2391
2386
|
throw std::runtime_error("the combination of data type is unsupported");
|
|
2392
2387
|
}
|
|
2393
|
-
}
|
|
2388
|
+
} // gemm()
|
|
2394
2389
|
|
|
2395
2390
|
/// Computes a batch of matrix-matrix product with general matrices.
|
|
2396
2391
|
/// \param [in] q The queue where the routine should be executed.
|
|
@@ -2412,7 +2407,7 @@ namespace dpct
|
|
|
2412
2407
|
/// \param [in] ldc Leading dimension of C.
|
|
2413
2408
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
2414
2409
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
2415
|
-
inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
|
|
2410
|
+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
|
2416
2411
|
int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
|
|
2417
2412
|
const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
|
|
2418
2413
|
library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
|
|
@@ -2450,7 +2445,7 @@ namespace dpct
|
|
|
2450
2445
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2451
2446
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2452
2447
|
{
|
|
2453
|
-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
|
|
2448
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
|
2454
2449
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
2455
2450
|
break;
|
|
2456
2451
|
}
|
|
@@ -2458,7 +2453,7 @@ namespace dpct
|
|
|
2458
2453
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2459
2454
|
library_data_t::real_float, library_data_t::real_float):
|
|
2460
2455
|
{
|
|
2461
|
-
detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
|
|
2456
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
|
2462
2457
|
q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
|
|
2463
2458
|
break;
|
|
2464
2459
|
}
|
|
@@ -2534,15 +2529,11 @@ namespace dpct
|
|
|
2534
2529
|
/// \param [in] stride_c Stride between the different C matrices.
|
|
2535
2530
|
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
|
|
2536
2531
|
/// \param [in] scaling_type Data type of the scaling factors.
|
|
2537
|
-
inline void gemm_batch(sycl::queue &q, oneapi::
|
|
2538
|
-
|
|
2539
|
-
|
|
2540
|
-
|
|
2541
|
-
|
|
2542
|
-
const void *beta, void *c, library_data_t c_type,
|
|
2543
|
-
int ldc, long long int stride_c, int batch_size,
|
|
2544
|
-
library_data_t scaling_type)
|
|
2545
|
-
{
|
|
2532
|
+
inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
|
|
2533
|
+
int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
|
|
2534
|
+
long long int stride_a, const void * b, library_data_t b_type, int ldb,
|
|
2535
|
+
long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
|
|
2536
|
+
long long int stride_c, int batch_size, library_data_t scaling_type) {
|
|
2546
2537
|
if (scaling_type == library_data_t::real_float &&
|
|
2547
2538
|
c_type == library_data_t::complex_float)
|
|
2548
2539
|
{
|
|
@@ -2611,20 +2602,18 @@ namespace dpct
|
|
|
2611
2602
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2612
2603
|
library_data_t::real_bfloat16, library_data_t::real_float):
|
|
2613
2604
|
{
|
|
2614
|
-
detail::gemm_batch_impl<oneapi::
|
|
2615
|
-
|
|
2616
|
-
|
|
2617
|
-
beta, c, ldc, stride_c, batch_size);
|
|
2605
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
|
|
2606
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
2607
|
+
batch_size);
|
|
2618
2608
|
break;
|
|
2619
2609
|
}
|
|
2620
2610
|
case detail::get_type_combination_id(
|
|
2621
2611
|
library_data_t::real_bfloat16, library_data_t::real_bfloat16,
|
|
2622
2612
|
library_data_t::real_float, library_data_t::real_float):
|
|
2623
2613
|
{
|
|
2624
|
-
detail::gemm_batch_impl<oneapi::
|
|
2625
|
-
|
|
2626
|
-
|
|
2627
|
-
stride_c, batch_size);
|
|
2614
|
+
detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
|
|
2615
|
+
q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
|
|
2616
|
+
batch_size);
|
|
2628
2617
|
break;
|
|
2629
2618
|
}
|
|
2630
2619
|
#endif
|