@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
|
@@ -1,9 +1,15 @@
|
|
|
1
1
|
#include "rope.hpp"
|
|
2
|
+
#include "ggml-sycl/common.hpp"
|
|
3
|
+
#include "ggml.h"
|
|
2
4
|
|
|
3
5
|
struct rope_corr_dims {
|
|
4
6
|
float v[2];
|
|
5
7
|
};
|
|
6
8
|
|
|
9
|
+
struct mrope_sections {
|
|
10
|
+
int v[4];
|
|
11
|
+
};
|
|
12
|
+
|
|
7
13
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
8
14
|
const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
|
|
9
15
|
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
|
|
@@ -28,23 +34,21 @@ static void rope_yarn(
|
|
|
28
34
|
*sin_theta = sycl::sin(theta) * mscale;
|
|
29
35
|
}
|
|
30
36
|
|
|
31
|
-
template<typename T, bool has_ff>
|
|
32
|
-
static void rope_norm(
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
37
|
-
item_ct1.get_local_id(1));
|
|
37
|
+
template <typename T, bool has_ff>
|
|
38
|
+
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
|
39
|
+
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
|
|
40
|
+
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
|
|
41
|
+
const sycl::nd_item<3> & item_ct1) {
|
|
42
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
|
|
38
43
|
|
|
39
44
|
if (i0 >= ne0) {
|
|
40
45
|
return;
|
|
41
46
|
}
|
|
42
47
|
|
|
43
|
-
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
44
|
-
item_ct1.get_local_id(2);
|
|
48
|
+
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
|
45
49
|
|
|
46
50
|
if (i0 >= n_dims) {
|
|
47
|
-
const int i = row*ne0 + i0;
|
|
51
|
+
const int i = row * ne0 + i0;
|
|
48
52
|
|
|
49
53
|
dst[i + 0] = x[i + 0];
|
|
50
54
|
dst[i + 1] = x[i + 1];
|
|
@@ -52,42 +56,43 @@ static void rope_norm(
|
|
|
52
56
|
return;
|
|
53
57
|
}
|
|
54
58
|
|
|
55
|
-
const int
|
|
56
|
-
const int
|
|
59
|
+
const int row0 = row % ne1;
|
|
60
|
+
const int channel0 = row / ne1;
|
|
57
61
|
|
|
58
|
-
const
|
|
62
|
+
const int i = row * ne0 + i0;
|
|
63
|
+
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
|
59
64
|
|
|
60
|
-
const float
|
|
65
|
+
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
|
66
|
+
|
|
67
|
+
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
61
68
|
|
|
62
69
|
float cos_theta;
|
|
63
70
|
float sin_theta;
|
|
64
71
|
|
|
65
|
-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
72
|
+
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
66
73
|
|
|
67
|
-
const float x0 = x[
|
|
68
|
-
const float x1 = x[
|
|
74
|
+
const float x0 = x[i2 + 0];
|
|
75
|
+
const float x1 = x[i2 + 1];
|
|
69
76
|
|
|
70
|
-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
|
|
71
|
-
dst[i + 1] = x0*sin_theta + x1*cos_theta;
|
|
77
|
+
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
|
|
78
|
+
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
|
|
72
79
|
}
|
|
73
80
|
|
|
74
|
-
template<typename T, bool has_ff>
|
|
75
|
-
static void rope_neox(
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
|
80
|
-
item_ct1.get_local_id(1));
|
|
81
|
+
template <typename T, bool has_ff>
|
|
82
|
+
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
|
83
|
+
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
|
84
|
+
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
|
|
85
|
+
const sycl::nd_item<3> & item_ct1) {
|
|
86
|
+
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
|
|
81
87
|
|
|
82
88
|
if (i0 >= ne0) {
|
|
83
89
|
return;
|
|
84
90
|
}
|
|
85
91
|
|
|
86
|
-
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
|
87
|
-
item_ct1.get_local_id(2);
|
|
92
|
+
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
|
88
93
|
|
|
89
94
|
if (i0 >= n_dims) {
|
|
90
|
-
const int i = row*ne0 + i0;
|
|
95
|
+
const int i = row * ne0 + i0;
|
|
91
96
|
|
|
92
97
|
dst[i + 0] = x[i + 0];
|
|
93
98
|
dst[i + 1] = x[i + 1];
|
|
@@ -95,38 +100,83 @@ static void rope_neox(
|
|
|
95
100
|
return;
|
|
96
101
|
}
|
|
97
102
|
|
|
98
|
-
const int
|
|
99
|
-
const int
|
|
103
|
+
const int row0 = row % ne1;
|
|
104
|
+
const int channel0 = row / ne1;
|
|
105
|
+
|
|
106
|
+
const int i = row * ne0 + i0 / 2;
|
|
107
|
+
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
|
100
108
|
|
|
101
|
-
const float theta_base = pos[
|
|
109
|
+
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
|
102
110
|
|
|
103
|
-
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
|
|
111
|
+
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
104
112
|
|
|
105
113
|
float cos_theta;
|
|
106
114
|
float sin_theta;
|
|
107
115
|
|
|
108
|
-
rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
116
|
+
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
117
|
+
|
|
118
|
+
const float x0 = x[i2 + 0];
|
|
119
|
+
const float x1 = x[i2 + n_dims / 2];
|
|
120
|
+
|
|
121
|
+
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
|
|
122
|
+
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
template <typename T, bool has_ff>
|
|
126
|
+
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
|
127
|
+
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
|
128
|
+
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
|
129
|
+
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
|
130
|
+
const sycl::nd_item<3> & item_ct1) {
|
|
131
|
+
// get index pos
|
|
132
|
+
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
|
133
|
+
if (i0 >= ne0) {
|
|
134
|
+
return;
|
|
135
|
+
}
|
|
136
|
+
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
|
137
|
+
const int row_x = row_dst % ne1;
|
|
138
|
+
const int channel_x = row_dst / ne1;
|
|
139
|
+
const int idst = (row_dst * ne0) + (i0 / 2);
|
|
140
|
+
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
|
141
|
+
|
|
142
|
+
const int sect_dims = sections.v[0] + sections.v[1];
|
|
143
|
+
const int sector = (i0 / 2) % sect_dims;
|
|
144
|
+
|
|
145
|
+
float theta_base = 0.0f;
|
|
146
|
+
if (sector < sections.v[0]) {
|
|
147
|
+
const int p = sector;
|
|
148
|
+
theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
|
|
149
|
+
} else {
|
|
150
|
+
// Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
|
|
151
|
+
const int p = sector - sections.v[0];
|
|
152
|
+
theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
|
|
153
|
+
}
|
|
109
154
|
|
|
110
|
-
const float
|
|
111
|
-
|
|
155
|
+
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
156
|
+
float cos_theta;
|
|
157
|
+
float sin_theta;
|
|
158
|
+
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
|
159
|
+
const float x0 = x[ix + 0];
|
|
160
|
+
const float x1 = x[ix + n_dims];
|
|
112
161
|
|
|
113
|
-
|
|
114
|
-
dst[
|
|
162
|
+
// store results in dst
|
|
163
|
+
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
|
164
|
+
dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
|
|
115
165
|
}
|
|
116
166
|
|
|
117
167
|
template <typename T>
|
|
118
|
-
static void rope_norm_sycl(
|
|
119
|
-
|
|
120
|
-
|
|
168
|
+
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
|
|
169
|
+
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
|
|
170
|
+
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
|
171
|
+
const float * freq_factors, queue_ptr stream) {
|
|
121
172
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
122
173
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
123
|
-
const int
|
|
174
|
+
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
|
124
175
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
|
125
176
|
|
|
126
|
-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
177
|
+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
127
178
|
|
|
128
|
-
dpct::has_capability_or_fail(stream->get_device(),
|
|
129
|
-
{sycl::aspect::fp16});
|
|
179
|
+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
130
180
|
|
|
131
181
|
if (freq_factors == nullptr) {
|
|
132
182
|
/*
|
|
@@ -134,82 +184,102 @@ static void rope_norm_sycl(
|
|
|
134
184
|
the limit. To get the device limit, query
|
|
135
185
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
136
186
|
*/
|
|
137
|
-
stream->parallel_for(
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
|
|
142
|
-
item_ct1);
|
|
143
|
-
});
|
|
187
|
+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
188
|
+
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
189
|
+
theta_scale, freq_factors, item_ct1);
|
|
190
|
+
});
|
|
144
191
|
} else {
|
|
145
192
|
/*
|
|
146
193
|
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
|
|
147
194
|
the limit. To get the device limit, query
|
|
148
195
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
149
196
|
*/
|
|
150
|
-
stream->parallel_for(
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
|
|
155
|
-
item_ct1);
|
|
156
|
-
});
|
|
197
|
+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
198
|
+
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
199
|
+
theta_scale, freq_factors, item_ct1);
|
|
200
|
+
});
|
|
157
201
|
}
|
|
158
202
|
}
|
|
159
203
|
|
|
160
204
|
template <typename T>
|
|
161
|
-
static void rope_neox_sycl(
|
|
162
|
-
|
|
163
|
-
|
|
205
|
+
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
|
|
206
|
+
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
|
|
207
|
+
const float freq_base, const float ext_factor, const float attn_factor,
|
|
208
|
+
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
|
164
209
|
GGML_ASSERT(ne0 % 2 == 0);
|
|
165
210
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
166
|
-
const int
|
|
211
|
+
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
|
167
212
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
|
168
213
|
|
|
169
|
-
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
214
|
+
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
|
170
215
|
|
|
171
|
-
dpct::has_capability_or_fail(stream->get_device(),
|
|
172
|
-
{sycl::aspect::fp16});
|
|
216
|
+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
173
217
|
|
|
174
218
|
if (freq_factors == nullptr) {
|
|
175
|
-
stream->parallel_for(
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
p_delta_rows, ext_factor, attn_factor,
|
|
180
|
-
corr_dims, theta_scale, freq_factors,
|
|
181
|
-
item_ct1);
|
|
182
|
-
});
|
|
219
|
+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
220
|
+
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
221
|
+
theta_scale, freq_factors, item_ct1);
|
|
222
|
+
});
|
|
183
223
|
} else {
|
|
184
|
-
stream->parallel_for(
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
p_delta_rows, ext_factor, attn_factor,
|
|
189
|
-
corr_dims, theta_scale, freq_factors,
|
|
190
|
-
item_ct1);
|
|
191
|
-
});
|
|
224
|
+
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
225
|
+
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
|
226
|
+
theta_scale, freq_factors, item_ct1);
|
|
227
|
+
});
|
|
192
228
|
}
|
|
193
229
|
}
|
|
194
230
|
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
231
|
+
// rope vision
|
|
232
|
+
template <typename T>
|
|
233
|
+
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
|
234
|
+
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
|
235
|
+
const float freq_scale, const float freq_base, const float ext_factor,
|
|
236
|
+
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
|
237
|
+
const mrope_sections sections, queue_ptr stream) {
|
|
238
|
+
GGML_ASSERT(ne0 % 2 == 0);
|
|
239
|
+
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
|
240
|
+
const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
|
241
|
+
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
|
242
|
+
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
|
243
|
+
|
|
244
|
+
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
|
245
|
+
// Add FP16 capability check if T could be sycl::half
|
|
246
|
+
if constexpr (std::is_same_v<T, sycl::half>) {
|
|
247
|
+
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
248
|
+
}
|
|
249
|
+
// launch kernel
|
|
250
|
+
if (freq_factors == nullptr) {
|
|
251
|
+
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
252
|
+
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
253
|
+
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
254
|
+
});
|
|
255
|
+
} else {
|
|
256
|
+
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
257
|
+
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
258
|
+
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
259
|
+
});
|
|
260
|
+
}
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
|
199
264
|
|
|
200
|
-
GGML_ASSERT(
|
|
265
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
|
201
266
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
|
202
|
-
GGML_ASSERT(
|
|
267
|
+
GGML_ASSERT(dst->src[0]->type == dst->type);
|
|
268
|
+
const int64_t ne00 = dst->src[0]->ne[0]; // head dims
|
|
269
|
+
const int64_t ne01 = dst->src[0]->ne[1]; // num heads
|
|
270
|
+
const int64_t ne02 = dst->src[0]->ne[2]; // num heads
|
|
271
|
+
const int64_t nr = ggml_nrows(dst->src[0]);
|
|
272
|
+
|
|
273
|
+
const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
|
|
274
|
+
const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
|
|
203
275
|
|
|
204
|
-
const int64_t ne00 = src0->ne[0];
|
|
205
|
-
const int64_t ne01 = src0->ne[1];
|
|
206
|
-
const int64_t nr = ggml_nrows(src0);
|
|
207
276
|
|
|
208
277
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
209
278
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
210
279
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
211
280
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
212
281
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
282
|
+
mrope_sections sections;
|
|
213
283
|
|
|
214
284
|
// RoPE alteration for extended context
|
|
215
285
|
float freq_base;
|
|
@@ -225,52 +295,68 @@ void ggml_sycl_op_rope(
|
|
|
225
295
|
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
226
296
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
227
297
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
298
|
+
memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
228
299
|
|
|
229
300
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
301
|
+
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
230
302
|
|
|
231
|
-
const int32_t * pos = (const int32_t *)
|
|
303
|
+
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
|
232
304
|
|
|
233
305
|
const float * freq_factors = nullptr;
|
|
234
|
-
if (
|
|
235
|
-
freq_factors = (const float *)
|
|
306
|
+
if (dst->src[2] != nullptr) {
|
|
307
|
+
freq_factors = (const float *) dst->src[2]->data;
|
|
236
308
|
}
|
|
237
309
|
|
|
238
310
|
rope_corr_dims corr_dims;
|
|
239
311
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
|
240
312
|
|
|
313
|
+
dpct::queue_ptr main_stream = ctx.stream();
|
|
314
|
+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
|
315
|
+
|
|
241
316
|
// compute
|
|
242
317
|
if (is_neox) {
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
attn_factor, corr_dims, freq_factors, main_stream
|
|
252
|
-
);
|
|
318
|
+
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
|
|
319
|
+
if (dst->src[0]->type == GGML_TYPE_F32) {
|
|
320
|
+
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
|
|
321
|
+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
|
|
322
|
+
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
323
|
+
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
|
|
324
|
+
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
325
|
+
main_stream);
|
|
253
326
|
} else {
|
|
254
327
|
GGML_ABORT("fatal error");
|
|
255
328
|
}
|
|
329
|
+
} else if (is_vision) {
|
|
330
|
+
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
|
331
|
+
if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
332
|
+
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
|
|
333
|
+
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
|
334
|
+
freq_factors, sections, main_stream);
|
|
335
|
+
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
|
336
|
+
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
|
337
|
+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
|
338
|
+
main_stream);
|
|
339
|
+
} else {
|
|
340
|
+
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
|
341
|
+
}
|
|
256
342
|
} else {
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
attn_factor, corr_dims, freq_factors, main_stream
|
|
266
|
-
);
|
|
343
|
+
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
|
|
344
|
+
if (dst->src[0]->type == GGML_TYPE_F32) {
|
|
345
|
+
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
|
|
346
|
+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
|
|
347
|
+
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
348
|
+
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
|
|
349
|
+
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
|
350
|
+
main_stream);
|
|
267
351
|
} else {
|
|
268
352
|
GGML_ABORT("fatal error");
|
|
269
353
|
}
|
|
270
354
|
}
|
|
355
|
+
}
|
|
271
356
|
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
357
|
+
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
358
|
+
GGML_SYCL_DEBUG("call %s\n", __func__);
|
|
359
|
+
ggml_sycl_op_rope(ctx, dst);
|
|
360
|
+
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
|
276
361
|
}
|
|
362
|
+
|
|
@@ -15,8 +15,6 @@
|
|
|
15
15
|
|
|
16
16
|
#include "common.hpp"
|
|
17
17
|
|
|
18
|
-
void
|
|
19
|
-
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
|
20
|
-
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
|
|
18
|
+
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
|
21
19
|
|
|
22
20
|
#endif // GGML_SYCL_ROPE_HPP
|