@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -254,14 +254,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
|
|
254
254
|
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
|
255
255
|
if (ncols < 1024) {
|
|
256
256
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
257
|
-
stream
|
|
258
|
-
cgh
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
});
|
|
257
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
258
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
259
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
260
|
+
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
|
261
|
+
nullptr, WARP_SIZE);
|
|
262
|
+
});
|
|
263
|
+
});
|
|
265
264
|
}
|
|
266
265
|
else {
|
|
267
266
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
@@ -272,16 +271,15 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
|
|
272
271
|
the limit. To get the device limit, query
|
|
273
272
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
274
273
|
*/
|
|
275
|
-
stream
|
|
274
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
276
275
|
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
|
|
277
276
|
sycl::range<1>(work_group_size / WARP_SIZE), cgh);
|
|
278
|
-
cgh
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
});
|
|
277
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
278
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
279
|
+
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
|
280
|
+
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
281
|
+
});
|
|
282
|
+
});
|
|
285
283
|
}
|
|
286
284
|
}
|
|
287
285
|
|
|
@@ -290,18 +288,14 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|
|
290
288
|
const int ne_elements, queue_ptr stream, int device) {
|
|
291
289
|
if (group_size < 1024) {
|
|
292
290
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
293
|
-
stream
|
|
291
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
294
292
|
const float eps_ct4 = eps;
|
|
295
|
-
cgh
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
|
302
|
-
nullptr, WARP_SIZE);
|
|
303
|
-
});
|
|
304
|
-
});
|
|
293
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
|
|
294
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
295
|
+
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
|
|
296
|
+
WARP_SIZE);
|
|
297
|
+
});
|
|
298
|
+
});
|
|
305
299
|
}
|
|
306
300
|
else {
|
|
307
301
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
@@ -313,22 +307,18 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|
|
313
307
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
314
308
|
*/
|
|
315
309
|
|
|
316
|
-
stream
|
|
310
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
317
311
|
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
|
318
312
|
cgh);
|
|
319
313
|
|
|
320
314
|
const float eps_ct4 = eps;
|
|
321
315
|
|
|
322
|
-
cgh
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
eps_ct4, item_ct1,
|
|
329
|
-
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
330
|
-
});
|
|
331
|
-
});
|
|
316
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
|
|
317
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
318
|
+
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
|
319
|
+
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
320
|
+
});
|
|
321
|
+
});
|
|
332
322
|
}
|
|
333
323
|
}
|
|
334
324
|
|
|
@@ -340,14 +330,13 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
|
|
340
330
|
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
|
341
331
|
if (ncols < 1024) {
|
|
342
332
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
343
|
-
stream
|
|
344
|
-
cgh
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
});
|
|
333
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
334
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
335
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
336
|
+
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
|
337
|
+
nullptr, WARP_SIZE);
|
|
338
|
+
});
|
|
339
|
+
});
|
|
351
340
|
}
|
|
352
341
|
else {
|
|
353
342
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
@@ -358,16 +347,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
|
|
358
347
|
the limit. To get the device limit, query
|
|
359
348
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
360
349
|
*/
|
|
361
|
-
stream
|
|
350
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
362
351
|
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
|
363
352
|
cgh);
|
|
364
|
-
cgh
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
});
|
|
353
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
|
354
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
355
|
+
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
|
356
|
+
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
357
|
+
});
|
|
358
|
+
});
|
|
371
359
|
}
|
|
372
360
|
}
|
|
373
361
|
|
|
@@ -378,16 +366,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
378
366
|
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
|
379
367
|
if (ncols < 1024) {
|
|
380
368
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
|
381
|
-
stream
|
|
382
|
-
cgh
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
|
388
|
-
nullptr, WARP_SIZE);
|
|
389
|
-
});
|
|
390
|
-
});
|
|
369
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
370
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
|
371
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
372
|
+
l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
|
|
373
|
+
});
|
|
374
|
+
});
|
|
391
375
|
}
|
|
392
376
|
else {
|
|
393
377
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
@@ -398,18 +382,15 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
398
382
|
the limit. To get the device limit, query
|
|
399
383
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
400
384
|
*/
|
|
401
|
-
stream
|
|
385
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
402
386
|
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
|
403
387
|
cgh);
|
|
404
|
-
cgh
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
get_pointer(s_sum_acc_ct1), work_group_size);
|
|
411
|
-
});
|
|
412
|
-
});
|
|
388
|
+
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
|
389
|
+
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
390
|
+
l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
|
|
391
|
+
work_group_size);
|
|
392
|
+
});
|
|
393
|
+
});
|
|
413
394
|
}
|
|
414
395
|
}
|
|
415
396
|
|
|
@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
|
|
|
47
47
|
|
|
48
48
|
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
|
49
49
|
|
|
50
|
-
if (i0 >= n_dims) {
|
|
51
|
-
const int i = row * ne0 + i0;
|
|
52
|
-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
|
53
|
-
return;
|
|
54
|
-
}
|
|
55
|
-
|
|
56
50
|
const int row0 = row % ne1;
|
|
57
51
|
const int channel0 = row / ne1;
|
|
58
52
|
|
|
59
53
|
const int i = row * ne0 + i0;
|
|
60
54
|
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
|
61
55
|
|
|
56
|
+
if (i0 >= n_dims) {
|
|
57
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
|
|
58
|
+
return;
|
|
59
|
+
}
|
|
60
|
+
|
|
62
61
|
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
|
63
62
|
|
|
64
63
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|
|
88
87
|
|
|
89
88
|
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
|
90
89
|
|
|
91
|
-
if (i0 >= n_dims) {
|
|
92
|
-
const int i = row * ne0 + i0;
|
|
93
|
-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
|
94
|
-
return;
|
|
95
|
-
}
|
|
96
|
-
|
|
97
90
|
const int row0 = row % ne1;
|
|
98
91
|
const int channel0 = row / ne1;
|
|
99
92
|
|
|
100
93
|
const int i = row * ne0 + i0 / 2;
|
|
101
94
|
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
|
102
95
|
|
|
96
|
+
if (i0 >= n_dims) {
|
|
97
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
|
|
98
|
+
return;
|
|
99
|
+
}
|
|
100
|
+
|
|
103
101
|
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
|
104
102
|
|
|
105
103
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
|
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
|
|
|
129
127
|
}
|
|
130
128
|
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
|
131
129
|
|
|
132
|
-
if (i0 >= n_dims) {
|
|
133
|
-
const int i = row_dst*ne0 + i0;
|
|
134
|
-
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
|
135
|
-
return;
|
|
136
|
-
}
|
|
137
|
-
|
|
138
130
|
const int row_x = row_dst % ne1;
|
|
139
131
|
const int channel_x = row_dst / ne1;
|
|
140
132
|
const int idst = (row_dst * ne0) + (i0 / 2);
|
|
141
133
|
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
|
142
134
|
|
|
135
|
+
if (i0 >= n_dims) {
|
|
136
|
+
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
|
|
137
|
+
return;
|
|
138
|
+
}
|
|
139
|
+
|
|
143
140
|
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
|
144
141
|
const int sec_w = sections.v[1] + sections.v[0];
|
|
145
142
|
const int sector = (i0 / 2) % sect_dims;
|
|
@@ -235,20 +232,22 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
|
235
232
|
the limit. To get the device limit, query
|
|
236
233
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
237
234
|
*/
|
|
238
|
-
stream
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
235
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
236
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
237
|
+
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
238
|
+
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
|
239
|
+
});
|
|
242
240
|
} else {
|
|
243
241
|
/*
|
|
244
242
|
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
|
|
245
243
|
the limit. To get the device limit, query
|
|
246
244
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
|
247
245
|
*/
|
|
248
|
-
stream
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
246
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
247
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
248
|
+
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
249
|
+
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
|
250
|
+
});
|
|
252
251
|
}
|
|
253
252
|
}
|
|
254
253
|
|
|
@@ -267,15 +266,17 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|
|
267
266
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
|
268
267
|
|
|
269
268
|
if (freq_factors == nullptr) {
|
|
270
|
-
stream
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
269
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
270
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
271
|
+
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
272
|
+
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
|
273
|
+
});
|
|
274
274
|
} else {
|
|
275
|
-
stream
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
275
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
276
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
277
|
+
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
|
278
|
+
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
|
279
|
+
});
|
|
279
280
|
}
|
|
280
281
|
}
|
|
281
282
|
|
|
@@ -298,12 +299,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
|
|
298
299
|
}
|
|
299
300
|
// launch kernel
|
|
300
301
|
if (freq_factors == nullptr) {
|
|
301
|
-
stream
|
|
302
|
+
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
302
303
|
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
303
304
|
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
304
305
|
});
|
|
305
306
|
} else {
|
|
306
|
-
stream
|
|
307
|
+
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
307
308
|
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
308
309
|
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
309
310
|
});
|
|
@@ -333,12 +334,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
|
|
333
334
|
}
|
|
334
335
|
// launch kernel
|
|
335
336
|
if (freq_factors == nullptr) {
|
|
336
|
-
stream
|
|
337
|
+
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
337
338
|
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
338
339
|
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
339
340
|
});
|
|
340
341
|
} else {
|
|
341
|
-
stream
|
|
342
|
+
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
|
342
343
|
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
|
343
344
|
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
|
344
345
|
});
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
#include "set_rows.hpp"
|
|
2
|
+
|
|
3
|
+
namespace utils {
|
|
4
|
+
template<typename T>
|
|
5
|
+
static constexpr bool is_arithmetic_v() {
|
|
6
|
+
return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
|
|
7
|
+
}
|
|
8
|
+
}
|
|
9
|
+
template<typename TIn, typename TOut>
|
|
10
|
+
static inline std::enable_if_t<utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void>
|
|
11
|
+
convert (const char* src, char* dst) {
|
|
12
|
+
auto src_val = *reinterpret_cast<const TIn*>(src);
|
|
13
|
+
auto dst_val = sycl::vec<TIn, 1>(src_val).template convert<TOut, sycl::rounding_mode::automatic>()[0];
|
|
14
|
+
*reinterpret_cast<TOut*>(dst) = dst_val;;
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
template<typename TIn, typename TOut>
|
|
18
|
+
static void k_set_rows(
|
|
19
|
+
const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
|
|
20
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne11, const int64_t ne12,
|
|
21
|
+
const size_t nb01, const size_t nb02, const size_t nb03,
|
|
22
|
+
const size_t nb10, const size_t nb11, const size_t nb12,
|
|
23
|
+
const size_t nb1, const size_t nb2, const size_t nb3,
|
|
24
|
+
const size_t src_type_size, const size_t dst_type_size,
|
|
25
|
+
const sycl::nd_item<3> & item_ct1) {
|
|
26
|
+
|
|
27
|
+
const int i03 = item_ct1.get_group(0);
|
|
28
|
+
const int i02 = item_ct1.get_group(1);
|
|
29
|
+
const int i01 = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); // Row index
|
|
30
|
+
|
|
31
|
+
if (i01 >= ne01) {
|
|
32
|
+
return;
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
const int i12 = i03 % ne12;
|
|
36
|
+
const int i11 = i02 % ne11;
|
|
37
|
+
const int i10 = i01;
|
|
38
|
+
|
|
39
|
+
const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
|
|
40
|
+
|
|
41
|
+
const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
|
|
42
|
+
char * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
|
|
43
|
+
|
|
44
|
+
for (int col = item_ct1.get_local_id(0); col < ne00; col += item_ct1.get_local_range(0)) {
|
|
45
|
+
const char * src_elem = src0_row + col * src_type_size;
|
|
46
|
+
char * dst_elem = dst_row_ptr + col * dst_type_size;
|
|
47
|
+
convert<TIn, TOut>(src_elem, dst_elem);
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
template<typename TIn, typename TOut>
|
|
52
|
+
static void set_rows_sycl(
|
|
53
|
+
const char * src0_d, const int64_t * src1_d, char * dst_d,
|
|
54
|
+
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
|
55
|
+
const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03,
|
|
56
|
+
const size_t nb10, const size_t nb11, const size_t nb12,
|
|
57
|
+
const size_t nb1, const size_t nb2, const size_t nb3,
|
|
58
|
+
const size_t src_type_size, const size_t dst_type_size,
|
|
59
|
+
queue_ptr stream) {
|
|
60
|
+
|
|
61
|
+
constexpr int max_threads_per_row = 64; // KEEPING 64 for now
|
|
62
|
+
const int threads_per_row = std::min((int)ne00, max_threads_per_row);
|
|
63
|
+
|
|
64
|
+
constexpr int max_threads_per_block = 64;
|
|
65
|
+
const int rows_per_block = std::max(1, max_threads_per_block / threads_per_row);
|
|
66
|
+
|
|
67
|
+
const sycl::range<3> block_size(1, rows_per_block, threads_per_row);
|
|
68
|
+
const sycl::range<3> grid_size(ne03, ne02, (ne01 + rows_per_block - 1) / rows_per_block);
|
|
69
|
+
|
|
70
|
+
sycl_parallel_for(
|
|
71
|
+
stream,
|
|
72
|
+
sycl::nd_range<3>(grid_size * block_size, block_size),
|
|
73
|
+
[=](sycl::nd_item<3> item_ct1) {
|
|
74
|
+
k_set_rows<TIn, TOut>(
|
|
75
|
+
src0_d, src1_d, dst_d,
|
|
76
|
+
ne00, ne01, ne11, ne12,
|
|
77
|
+
nb01, nb02, nb03,
|
|
78
|
+
nb10, nb11, nb12,
|
|
79
|
+
nb1, nb2, nb3,
|
|
80
|
+
src_type_size, dst_type_size,
|
|
81
|
+
item_ct1
|
|
82
|
+
);
|
|
83
|
+
}
|
|
84
|
+
);
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
// SYCL backend entry point for the SET_ROWS op.
//
// Scatters the F32 rows of dst->src[0] into dst. The target row indices come
// from the I64 tensor dst->src[1]. Each element is converted to dst's type
// (F32 or F16); any other destination type aborts.
void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    // Values must be F32 and row indices I64.
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_I64);

    GGML_TENSOR_BINARY_OP_LOCALS

    const char    * src_bytes = static_cast<const char *>(src0->data);
    const int64_t * row_ids   = static_cast<const int64_t *>(src1->data);
    char          * dst_bytes = static_cast<char *>(dst->data);
    dpct::queue_ptr stream    = ctx.stream();

    if (dst->type == GGML_TYPE_F32) {
        set_rows_sycl<float, float>(
            src_bytes, row_ids, dst_bytes,
            ne00, ne01, ne02, ne03, ne11, ne12,
            nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3,
            sizeof(float), sizeof(float), stream);
    } else if (dst->type == GGML_TYPE_F16) {
        // Half-precision output requires fp16 support on the device.
        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
        set_rows_sycl<float, sycl::half>(
            src_bytes, row_ids, dst_bytes,
            ne00, ne01, ne02, ne03, ne11, ne12,
            nb01, nb02, nb03, nb10, nb11, nb12, nb1, nb2, nb3,
            sizeof(float), sizeof(sycl::half), stream);
    } else {
        GGML_ABORT("Unsupported tensor type!");
    }
}
|
|
@@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
|
|
|
127
127
|
const int nrows_y, const float scale, const float max_bias, const float m0,
|
|
128
128
|
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
|
129
129
|
const size_t n_local_scratch, queue_ptr stream) {
|
|
130
|
-
stream
|
|
130
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
131
131
|
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
|
132
132
|
|
|
133
|
-
|
|
134
|
-
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
133
|
+
sycl_parallel_for(
|
|
134
|
+
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
135
135
|
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
136
136
|
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
|
137
137
|
nrows_y, scale, max_bias, m0,
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
#include "sycl_hw.hpp"
|
|
2
2
|
|
|
3
|
-
|
|
3
|
+
// TODO: currently not used
|
|
4
|
+
/*
|
|
4
5
|
sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
|
|
5
6
|
sycl_hw_info res;
|
|
6
7
|
int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
|
|
@@ -11,3 +12,4 @@ sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
|
|
|
11
12
|
|
|
12
13
|
return res;
|
|
13
14
|
}
|
|
15
|
+
*/
|
|
@@ -10,6 +10,8 @@
|
|
|
10
10
|
|
|
11
11
|
namespace syclex = sycl::ext::oneapi::experimental;
|
|
12
12
|
|
|
13
|
+
// TODO: currently not used
|
|
14
|
+
/*
|
|
13
15
|
struct sycl_hw_info {
|
|
14
16
|
syclex::architecture arch;
|
|
15
17
|
int32_t device_id;
|
|
@@ -18,6 +20,7 @@ struct sycl_hw_info {
|
|
|
18
20
|
bool is_in_vector(std::vector<int> &vec, int item);
|
|
19
21
|
|
|
20
22
|
sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
|
|
23
|
+
*/
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
#endif // SYCL_HW_HPP
|
|
@@ -45,14 +45,9 @@ static void timestep_embedding_f32_sycl(
|
|
|
45
45
|
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
|
|
46
46
|
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
|
|
47
47
|
sycl::range<3> gridDim(1, ne00, num_blocks);
|
|
48
|
-
stream
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
[=](sycl::nd_item<3> item_ct1) {
|
|
52
|
-
timestep_embedding_f32(
|
|
53
|
-
x, dst, nb1, dim, max_period, item_ct1
|
|
54
|
-
);
|
|
55
|
-
});
|
|
48
|
+
sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
49
|
+
timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
|
|
50
|
+
});
|
|
56
51
|
}
|
|
57
52
|
|
|
58
53
|
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
@@ -207,12 +207,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
207
207
|
|
|
208
208
|
// Submit kernel
|
|
209
209
|
if (C / H == WKV_BLOCK_SIZE) {
|
|
210
|
-
stream
|
|
210
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
211
211
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
|
212
212
|
|
|
213
|
-
|
|
214
|
-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
|
215
|
-
[=](sycl::nd_item<3> item_ct1) {
|
|
213
|
+
sycl_parallel_for(
|
|
214
|
+
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
216
215
|
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
|
|
217
216
|
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
|
218
217
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
|
@@ -220,12 +219,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
220
219
|
});
|
|
221
220
|
});
|
|
222
221
|
} else {
|
|
223
|
-
stream
|
|
222
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
224
223
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
|
225
224
|
|
|
226
|
-
|
|
227
|
-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
|
228
|
-
[=](sycl::nd_item<3> item_ct1) {
|
|
225
|
+
sycl_parallel_for(
|
|
226
|
+
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
229
227
|
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
|
230
228
|
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
|
231
229
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
|
@@ -264,12 +262,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
264
262
|
|
|
265
263
|
// Submit kernel
|
|
266
264
|
if (C / H == WKV_BLOCK_SIZE) {
|
|
267
|
-
stream
|
|
265
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
268
266
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
|
269
267
|
|
|
270
|
-
|
|
271
|
-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
|
272
|
-
[=](sycl::nd_item<3> item_ct1) {
|
|
268
|
+
sycl_parallel_for(
|
|
269
|
+
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
273
270
|
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
|
|
274
271
|
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
|
275
272
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
|
@@ -277,12 +274,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
|
|
277
274
|
});
|
|
278
275
|
});
|
|
279
276
|
} else {
|
|
280
|
-
stream
|
|
277
|
+
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
281
278
|
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
|
282
279
|
|
|
283
|
-
|
|
284
|
-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
|
285
|
-
[=](sycl::nd_item<3> item_ct1) {
|
|
280
|
+
sycl_parallel_for(
|
|
281
|
+
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
|
286
282
|
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
|
287
283
|
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
|
288
284
|
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|