@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
|
@@ -12,115 +12,125 @@
|
|
|
12
12
|
|
|
13
13
|
#include "im2col.hpp"
|
|
14
14
|
|
|
15
|
+
#include <sycl/sycl.hpp>
|
|
16
|
+
#include <type_traits> // For std::is_same_v
|
|
17
|
+
|
|
18
|
+
#include "ggml.h"
|
|
19
|
+
|
|
15
20
|
template <typename T>
|
|
16
|
-
static void im2col_kernel(
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
|
|
20
|
-
const sycl::nd_item<3> &item_ct1) {
|
|
21
|
+
static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW,
|
|
22
|
+
int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
|
|
23
|
+
int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) {
|
|
21
24
|
const int64_t work_group_size = item_ct1.get_local_range(2);
|
|
22
|
-
const int64_t global_id
|
|
25
|
+
const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2));
|
|
23
26
|
|
|
24
27
|
// make each work-item deal with more elements since sycl global range can not exceed max int
|
|
25
|
-
for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
|
|
26
|
-
|
|
28
|
+
for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) {
|
|
27
29
|
const int64_t ksize = OW * (KH > 1 ? KW : 1);
|
|
28
|
-
const int64_t kx
|
|
29
|
-
const int64_t kd
|
|
30
|
-
const int64_t ky
|
|
31
|
-
const int64_t ix
|
|
32
|
-
|
|
33
|
-
const int64_t
|
|
34
|
-
const int64_t
|
|
35
|
-
const int64_t
|
|
36
|
-
|
|
37
|
-
const int64_t iiw = ix * s0 + kx * d0 - p0;
|
|
38
|
-
const int64_t iih = oh * s1 + ky * d1 - p1;
|
|
39
|
-
|
|
40
|
-
const int64_t offset_dst =
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
|
|
30
|
+
const int64_t kx = i / ksize;
|
|
31
|
+
const int64_t kd = kx * ksize;
|
|
32
|
+
const int64_t ky = (i - kd) / OW;
|
|
33
|
+
const int64_t ix = i % OW;
|
|
34
|
+
|
|
35
|
+
const int64_t oh = item_ct1.get_group(1);
|
|
36
|
+
const int64_t batch = item_ct1.get_group(0) / IC;
|
|
37
|
+
const int64_t ic = item_ct1.get_group(0) % IC;
|
|
38
|
+
|
|
39
|
+
const int64_t iiw = (ix * s0) + (kx * d0) - p0;
|
|
40
|
+
const int64_t iih = (oh * s1) + (ky * d1) - p1;
|
|
41
|
+
|
|
42
|
+
const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx);
|
|
43
|
+
|
|
44
|
+
const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset);
|
|
45
|
+
const int64_t offset_src = offset_src_base + (iih * IW) + iiw;
|
|
46
|
+
|
|
47
|
+
const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW);
|
|
48
|
+
const float src_val = out_of_bounds ? 0.0f : x[offset_src];
|
|
49
|
+
|
|
50
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
51
|
+
dst[offset_dst] = sycl::half(src_val);
|
|
52
|
+
} else if constexpr (std::is_same_v<T, float>) {
|
|
53
|
+
dst[offset_dst] = src_val;
|
|
53
54
|
}
|
|
54
55
|
}
|
|
55
56
|
}
|
|
56
57
|
|
|
57
58
|
template <typename T>
|
|
58
|
-
static void
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
int s0, int s1, int p0, int p1, int d0, int d1,
|
|
62
|
-
queue_ptr stream) {
|
|
59
|
+
static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
|
|
60
|
+
int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
|
|
61
|
+
int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
|
|
63
62
|
const int64_t parallel_elements = OW * KW * KH;
|
|
64
|
-
const int64_t num_blocks
|
|
63
|
+
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
|
|
65
64
|
|
|
66
65
|
// decrease global range when it exceeds the max int
|
|
67
66
|
int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
|
|
67
|
+
|
|
68
68
|
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
|
|
69
69
|
sycl::range<3> local_range(1, 1, local_size);
|
|
70
70
|
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
71
|
+
const int64_t CHW = IC * KH * KW;
|
|
72
|
+
|
|
73
|
+
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
|
|
74
|
+
im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
|
|
75
|
+
p0, p1, d0, d1, item_ct1);
|
|
76
|
+
});
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH,
|
|
80
|
+
int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset,
|
|
81
|
+
int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
|
|
82
|
+
if (!stream->get_device().has(sycl::aspect::fp16)) {
|
|
83
|
+
throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported),
|
|
84
|
+
"Device does not support half precision (fp16) operations!");
|
|
82
85
|
}
|
|
86
|
+
im2col_sycl_internal<sycl::half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0,
|
|
87
|
+
p1, d0, d1, stream);
|
|
83
88
|
}
|
|
84
89
|
|
|
85
|
-
void
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
90
|
+
static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
|
|
91
|
+
int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0,
|
|
92
|
+
int s1, int p0, int p1, int d0, int d1, queue_ptr stream) {
|
|
93
|
+
im2col_sycl_internal<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1,
|
|
94
|
+
d0, d1, stream);
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
98
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
99
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
89
100
|
|
|
90
|
-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
91
101
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
92
102
|
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
|
93
103
|
|
|
94
|
-
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
|
95
|
-
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
|
96
|
-
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
|
97
|
-
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
|
98
|
-
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
|
99
|
-
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
|
104
|
+
const int32_t s0 = ((const int32_t *) (dst->op_params))[0];
|
|
105
|
+
const int32_t s1 = ((const int32_t *) (dst->op_params))[1];
|
|
106
|
+
const int32_t p0 = ((const int32_t *) (dst->op_params))[2];
|
|
107
|
+
const int32_t p1 = ((const int32_t *) (dst->op_params))[3];
|
|
108
|
+
const int32_t d0 = ((const int32_t *) (dst->op_params))[4];
|
|
109
|
+
const int32_t d1 = ((const int32_t *) (dst->op_params))[5];
|
|
100
110
|
|
|
101
|
-
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
|
111
|
+
const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1;
|
|
102
112
|
|
|
103
113
|
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
|
104
114
|
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
|
105
|
-
const int64_t IW =
|
|
115
|
+
const int64_t IW = src1->ne[0];
|
|
106
116
|
|
|
107
117
|
const int64_t KH = is_2D ? src0->ne[1] : 1;
|
|
108
|
-
const int64_t KW =
|
|
118
|
+
const int64_t KW = src0->ne[0];
|
|
109
119
|
|
|
110
120
|
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
|
111
|
-
const int64_t OW =
|
|
121
|
+
const int64_t OW = dst->ne[1];
|
|
112
122
|
|
|
113
|
-
const size_t
|
|
114
|
-
const int64_t batch
|
|
115
|
-
const size_t
|
|
123
|
+
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float);
|
|
124
|
+
const int64_t batch = src1->ne[is_2D ? 3 : 2];
|
|
125
|
+
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float);
|
|
126
|
+
|
|
127
|
+
queue_ptr stream = ctx.stream();
|
|
116
128
|
|
|
117
129
|
if (dst->type == GGML_TYPE_F16) {
|
|
118
|
-
|
|
130
|
+
im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
|
|
131
|
+
batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
|
|
119
132
|
} else {
|
|
120
|
-
|
|
133
|
+
im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch,
|
|
134
|
+
batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
|
|
121
135
|
}
|
|
122
|
-
|
|
123
|
-
GGML_UNUSED(src0);
|
|
124
|
-
GGML_UNUSED(src0_dd);
|
|
125
|
-
GGML_UNUSED(ctx);
|
|
126
136
|
}
|
|
@@ -16,8 +16,6 @@
|
|
|
16
16
|
#include "common.hpp"
|
|
17
17
|
|
|
18
18
|
void ggml_sycl_op_im2col(
|
|
19
|
-
ggml_backend_sycl_context & ctx,
|
|
20
|
-
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
|
|
21
|
-
const queue_ptr &main_stream);
|
|
19
|
+
ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
|
22
20
|
|
|
23
21
|
#endif // GGML_SYCL_IM2COL_HPP
|
|
@@ -1,6 +1,61 @@
|
|
|
1
1
|
#include "mmvq.hpp"
|
|
2
|
+
|
|
3
|
+
#include "ggml.h"
|
|
4
|
+
#include "common.hpp"
|
|
5
|
+
#include "quants.hpp"
|
|
2
6
|
#include "vecdotq.hpp"
|
|
3
|
-
|
|
7
|
+
|
|
8
|
+
template <typename reorder_vec_dot_q_sycl>
|
|
9
|
+
static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
|
10
|
+
const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) {
|
|
11
|
+
using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>;
|
|
12
|
+
using block_traits = typename block_type::traits;
|
|
13
|
+
|
|
14
|
+
const auto sg = nd_item.get_sub_group();
|
|
15
|
+
const int sg_range = sg.get_group_linear_range();
|
|
16
|
+
const int workgroup_id = nd_item.get_group_linear_id();
|
|
17
|
+
const int sg_id = sg.get_group_linear_id();
|
|
18
|
+
const int row = workgroup_id * sg_range + sg_id;
|
|
19
|
+
|
|
20
|
+
if (row >= nrows) {
|
|
21
|
+
return;
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
const int blocks_per_row = ncols / block_traits::qk;
|
|
25
|
+
constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
|
|
26
|
+
constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
|
|
27
|
+
const int nblocks = nrows * (ncols / block_traits::qk);
|
|
28
|
+
|
|
29
|
+
static_assert(blocks_per_subgroup > 0);
|
|
30
|
+
static_assert(block_elements_per_subgroup > 0);
|
|
31
|
+
|
|
32
|
+
const block_q8_1 * y = (const block_q8_1 *) vy;
|
|
33
|
+
|
|
34
|
+
float partial_sum = 0.0f;
|
|
35
|
+
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
|
|
36
|
+
const int ibx = row * blocks_per_row + i; // x block index
|
|
37
|
+
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
|
|
38
|
+
const int bx_offset = block_type::get_block_offset(ibx);
|
|
39
|
+
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
|
|
40
|
+
|
|
41
|
+
// Y block index that aligns with ibx
|
|
42
|
+
const int iby = i * block_type::block_to_q8_1_ratio();
|
|
43
|
+
|
|
44
|
+
#pragma unroll
|
|
45
|
+
for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) {
|
|
46
|
+
// x block quant index when casting the quants to int
|
|
47
|
+
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
|
|
48
|
+
|
|
49
|
+
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>());
|
|
54
|
+
|
|
55
|
+
if (sg.leader()) {
|
|
56
|
+
dst[row] = sum;
|
|
57
|
+
}
|
|
58
|
+
}
|
|
4
59
|
|
|
5
60
|
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
|
|
6
61
|
static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
|
@@ -480,26 +535,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
|
|
|
480
535
|
}
|
|
481
536
|
}
|
|
482
537
|
|
|
483
|
-
static void
|
|
484
|
-
|
|
485
|
-
|
|
538
|
+
static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
539
|
+
const int nrows, dpct::queue_ptr stream) {
|
|
540
|
+
GGML_ASSERT(ncols % QK4_0 == 0);
|
|
541
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
542
|
+
constexpr size_t num_subgroups = 16;
|
|
543
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
544
|
+
|
|
545
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
|
546
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
547
|
+
|
|
548
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
549
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
550
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
551
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
|
|
552
|
+
nd_item);
|
|
553
|
+
});
|
|
554
|
+
});
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
|
|
486
558
|
dpct::queue_ptr stream) {
|
|
487
559
|
GGML_ASSERT(ncols % QK4_0 == 0);
|
|
488
560
|
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
489
561
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
490
562
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
491
|
-
{
|
|
492
|
-
|
|
493
|
-
stream->submit([&](sycl::handler &cgh) {
|
|
494
563
|
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
});
|
|
564
|
+
{
|
|
565
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
566
|
+
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
567
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
568
|
+
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
|
569
|
+
vx, vy, dst, ncols, nrows, item_ct1);
|
|
570
|
+
});
|
|
503
571
|
});
|
|
504
572
|
}
|
|
505
573
|
}
|
|
@@ -672,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|
|
672
740
|
}
|
|
673
741
|
}
|
|
674
742
|
|
|
743
|
+
static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
|
|
744
|
+
const int nrows, dpct::queue_ptr stream) {
|
|
745
|
+
GGML_ASSERT(ncols % QK_K == 0);
|
|
746
|
+
|
|
747
|
+
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
|
|
748
|
+
constexpr size_t num_subgroups = 16;
|
|
749
|
+
GGML_ASSERT(block_num_y % num_subgroups == 0);
|
|
750
|
+
|
|
751
|
+
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
|
752
|
+
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
|
753
|
+
|
|
754
|
+
stream->submit([&](sycl::handler & cgh) {
|
|
755
|
+
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
|
756
|
+
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
757
|
+
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
|
|
758
|
+
nrows, nd_item);
|
|
759
|
+
});
|
|
760
|
+
});
|
|
761
|
+
}
|
|
762
|
+
|
|
763
|
+
|
|
675
764
|
static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|
676
765
|
float *dst, const int ncols,
|
|
677
766
|
const int nrows,
|
|
@@ -916,93 +1005,100 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
|
|
916
1005
|
}
|
|
917
1006
|
}
|
|
918
1007
|
|
|
919
|
-
void ggml_sycl_op_mul_mat_vec_q(
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
|
925
|
-
const dpct::queue_ptr &stream) {
|
|
926
|
-
|
|
1008
|
+
void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
|
|
1009
|
+
ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
|
1010
|
+
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low,
|
|
1011
|
+
const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size,
|
|
1012
|
+
const dpct::queue_ptr & stream) {
|
|
927
1013
|
const int64_t ne10 = src1->ne[0];
|
|
928
1014
|
GGML_ASSERT(ne10 % QK8_1 == 0);
|
|
929
1015
|
|
|
930
|
-
const int64_t ne00
|
|
1016
|
+
const int64_t ne00 = src0->ne[0];
|
|
931
1017
|
const int64_t row_diff = row_high - row_low;
|
|
932
1018
|
|
|
933
1019
|
int id;
|
|
934
|
-
SYCL_CHECK(
|
|
935
|
-
CHECK_TRY_ERROR(id = get_current_device_id()));
|
|
1020
|
+
SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id()));
|
|
936
1021
|
const size_t q8_1_ts = sizeof(block_q8_1);
|
|
937
1022
|
const size_t q8_1_bs = QK8_1;
|
|
938
1023
|
// the main device has a larger memory buffer to hold the results from all GPUs
|
|
939
1024
|
// nrows_dst == nrows of the matrix that the kernel writes into
|
|
940
1025
|
|
|
941
|
-
for (int i = 0; i < src1_ncols; i++)
|
|
942
|
-
{
|
|
1026
|
+
for (int i = 0; i < src1_ncols; i++) {
|
|
943
1027
|
const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs;
|
|
944
|
-
const char* src1_ddq_i_bs
|
|
945
|
-
float*
|
|
1028
|
+
const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset;
|
|
1029
|
+
float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
|
|
946
1030
|
switch (src0->type) {
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1031
|
+
case GGML_TYPE_Q4_0:
|
|
1032
|
+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
1033
|
+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
1034
|
+
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n");
|
|
1035
|
+
reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1036
|
+
} else {
|
|
1037
|
+
GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n");
|
|
1038
|
+
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1039
|
+
}
|
|
1040
|
+
break;
|
|
1041
|
+
case GGML_TYPE_Q4_1:
|
|
1042
|
+
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1043
|
+
break;
|
|
1044
|
+
case GGML_TYPE_Q5_0:
|
|
1045
|
+
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1046
|
+
break;
|
|
1047
|
+
case GGML_TYPE_Q5_1:
|
|
1048
|
+
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1049
|
+
break;
|
|
1050
|
+
case GGML_TYPE_Q8_0:
|
|
1051
|
+
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1052
|
+
break;
|
|
1053
|
+
case GGML_TYPE_Q2_K:
|
|
1054
|
+
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1055
|
+
break;
|
|
1056
|
+
case GGML_TYPE_Q3_K:
|
|
1057
|
+
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1058
|
+
break;
|
|
1059
|
+
case GGML_TYPE_Q4_K:
|
|
1060
|
+
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
|
|
1061
|
+
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
|
|
1062
|
+
reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1063
|
+
} else {
|
|
1064
|
+
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1065
|
+
}
|
|
1066
|
+
break;
|
|
1067
|
+
case GGML_TYPE_Q5_K:
|
|
1068
|
+
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1069
|
+
break;
|
|
1070
|
+
case GGML_TYPE_Q6_K:
|
|
1071
|
+
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1072
|
+
break;
|
|
1073
|
+
case GGML_TYPE_IQ1_S:
|
|
1074
|
+
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1075
|
+
break;
|
|
1076
|
+
case GGML_TYPE_IQ1_M:
|
|
1077
|
+
mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1078
|
+
break;
|
|
1079
|
+
case GGML_TYPE_IQ2_XXS:
|
|
1080
|
+
mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1081
|
+
break;
|
|
1082
|
+
case GGML_TYPE_IQ2_XS:
|
|
1083
|
+
mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1084
|
+
break;
|
|
1085
|
+
case GGML_TYPE_IQ2_S:
|
|
1086
|
+
mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1087
|
+
break;
|
|
1088
|
+
case GGML_TYPE_IQ3_XXS:
|
|
1089
|
+
mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1090
|
+
break;
|
|
1091
|
+
case GGML_TYPE_IQ3_S:
|
|
1092
|
+
mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1093
|
+
break;
|
|
1094
|
+
case GGML_TYPE_IQ4_NL:
|
|
1095
|
+
mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1096
|
+
break;
|
|
1097
|
+
case GGML_TYPE_IQ4_XS:
|
|
1098
|
+
mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
|
|
1099
|
+
break;
|
|
1100
|
+
default:
|
|
1101
|
+
GGML_ABORT("fatal error");
|
|
1006
1102
|
}
|
|
1007
1103
|
}
|
|
1008
1104
|
GGML_UNUSED(src1);
|