@novastera-oss/llamarn 0.2.9 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +0 -1
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +8 -8
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +62 -1
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +22 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +1250 -43
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +21 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +13 -3
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +173 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +44 -38
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +126 -8
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +138 -18
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +11 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +28 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1206 -163
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +6 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +36 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +142 -9
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +31 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +86 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -64
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +47 -60
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +29 -42
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +46 -59
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +38 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +23 -36
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +255 -99
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -695
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +104 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +27 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +80 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +48 -12
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +572 -106
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +599 -105
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +800 -42
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +693 -1034
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +191 -55
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +15 -18
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +991 -307
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +59 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +17 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +18 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +84 -9
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +386 -67
- package/cpp/llama.cpp/ggml/src/gguf.cpp +8 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +307 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +8 -2
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +122 -47
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +12 -3
- package/cpp/llama.cpp/include/llama.h +15 -47
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +316 -3
- package/cpp/llama.cpp/src/llama-arch.h +23 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +103 -71
- package/cpp/llama.cpp/src/llama-batch.h +31 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +58 -1
- package/cpp/llama.cpp/src/llama-chat.h +3 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +310 -211
- package/cpp/llama.cpp/src/llama-graph.h +184 -122
- package/cpp/llama.cpp/src/llama-hparams.cpp +47 -1
- package/cpp/llama.cpp/src/llama-hparams.h +13 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +38 -22
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +7 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +849 -304
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +143 -47
- package/cpp/llama.cpp/src/llama-kv-cells.h +62 -10
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +10 -4
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +3 -1
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +36 -11
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +3 -0
- package/cpp/llama.cpp/src/llama-model.cpp +3545 -719
- package/cpp/llama.cpp/src/llama-model.h +21 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +376 -10
- package/cpp/llama.cpp/src/llama-vocab.h +43 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +22 -4
- package/ios/include/llama.h +15 -47
- package/ios/libs/llama.xcframework/Info.plist +13 -13
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3766
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -4890
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -4861
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3764
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -4926
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -4897
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3794
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +173 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -47
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
109
109
|
}
|
|
110
110
|
|
|
111
111
|
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
112
|
+
#pragma METAL fp math_mode(safe)
|
|
112
113
|
float amax = 0.0f; // absolute max
|
|
113
114
|
float max = 0.0f;
|
|
114
115
|
|
|
@@ -138,6 +139,7 @@ void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
|
138
139
|
}
|
|
139
140
|
|
|
140
141
|
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
|
142
|
+
#pragma METAL fp math_mode(safe)
|
|
141
143
|
float min = FLT_MAX;
|
|
142
144
|
float max = -FLT_MAX;
|
|
143
145
|
|
|
@@ -166,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
|
|
166
168
|
}
|
|
167
169
|
|
|
168
170
|
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
|
171
|
+
#pragma METAL fp math_mode(safe)
|
|
169
172
|
float amax = 0.0f; // absolute max
|
|
170
173
|
float max = 0.0f;
|
|
171
174
|
|
|
@@ -203,6 +206,7 @@ void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
|
|
203
206
|
}
|
|
204
207
|
|
|
205
208
|
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
|
209
|
+
#pragma METAL fp math_mode(safe)
|
|
206
210
|
float max = src[0];
|
|
207
211
|
float min = src[0];
|
|
208
212
|
|
|
@@ -239,6 +243,7 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
|
|
239
243
|
}
|
|
240
244
|
|
|
241
245
|
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
|
|
246
|
+
#pragma METAL fp math_mode(safe)
|
|
242
247
|
float amax = 0.0f; // absolute max
|
|
243
248
|
float max = 0.0f;
|
|
244
249
|
|
|
@@ -458,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
|
|
458
463
|
}
|
|
459
464
|
|
|
460
465
|
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
|
466
|
+
#pragma METAL fp math_mode(safe)
|
|
461
467
|
float amax = 0.0f; // absolute max
|
|
462
468
|
|
|
463
469
|
for (int j = 0; j < QK8_0; j++) {
|
|
@@ -826,7 +832,8 @@ enum ggml_sort_order {
|
|
|
826
832
|
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
|
827
833
|
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
|
828
834
|
// cons: not very efficient
|
|
829
|
-
|
|
835
|
+
template <int F>
|
|
836
|
+
kernel void kernel_add_fuse_impl(
|
|
830
837
|
constant ggml_metal_kargs_bin & args,
|
|
831
838
|
device const char * src0,
|
|
832
839
|
device const char * src1,
|
|
@@ -842,16 +849,39 @@ kernel void kernel_add(
|
|
|
842
849
|
const int i12 = i02%args.ne12;
|
|
843
850
|
const int i11 = i01%args.ne11;
|
|
844
851
|
|
|
845
|
-
device const
|
|
846
|
-
device
|
|
847
|
-
|
|
852
|
+
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
|
853
|
+
device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
|
854
|
+
|
|
855
|
+
device const float * src1_ptr[F];
|
|
856
|
+
for (short j = 0; j < F; ++j) {
|
|
857
|
+
src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
|
858
|
+
}
|
|
848
859
|
|
|
849
860
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
850
861
|
const int i10 = i0%args.ne10;
|
|
851
|
-
|
|
862
|
+
|
|
863
|
+
float res = src0_ptr[i0];
|
|
864
|
+
|
|
865
|
+
#pragma unroll
|
|
866
|
+
for (short j = 0; j < F; ++j) {
|
|
867
|
+
res += src1_ptr[j][i10];
|
|
868
|
+
}
|
|
869
|
+
|
|
870
|
+
dst_ptr[i0] = res;
|
|
852
871
|
}
|
|
853
872
|
}
|
|
854
873
|
|
|
874
|
+
typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
|
|
875
|
+
|
|
876
|
+
template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
|
877
|
+
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
|
|
878
|
+
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
|
|
879
|
+
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
|
|
880
|
+
template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
|
|
881
|
+
template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
|
|
882
|
+
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
|
|
883
|
+
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
|
|
884
|
+
|
|
855
885
|
kernel void kernel_sub(
|
|
856
886
|
constant ggml_metal_kargs_bin & args,
|
|
857
887
|
device const char * src0,
|
|
@@ -869,7 +899,7 @@ kernel void kernel_sub(
|
|
|
869
899
|
const int i11 = i01%args.ne11;
|
|
870
900
|
|
|
871
901
|
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
|
872
|
-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
|
|
902
|
+
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
873
903
|
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
874
904
|
|
|
875
905
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
@@ -894,9 +924,9 @@ kernel void kernel_mul(
|
|
|
894
924
|
const int i12 = i02%args.ne12;
|
|
895
925
|
const int i11 = i01%args.ne11;
|
|
896
926
|
|
|
897
|
-
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
|
898
|
-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
|
|
899
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
|
|
927
|
+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
|
928
|
+
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
929
|
+
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
900
930
|
|
|
901
931
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
902
932
|
const int i10 = i0%args.ne10;
|
|
@@ -920,9 +950,9 @@ kernel void kernel_div(
|
|
|
920
950
|
const int i12 = i02%args.ne12;
|
|
921
951
|
const int i11 = i01%args.ne11;
|
|
922
952
|
|
|
923
|
-
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
|
924
|
-
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
|
|
925
|
-
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
|
|
953
|
+
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
|
954
|
+
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
|
|
955
|
+
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
|
926
956
|
|
|
927
957
|
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
|
928
958
|
const int i10 = i0%args.ne10;
|
|
@@ -964,60 +994,161 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
|
|
|
964
994
|
|
|
965
995
|
// assumption: src1 is a row
|
|
966
996
|
// broadcast src1 into src0
|
|
967
|
-
|
|
997
|
+
template <short F>
|
|
998
|
+
kernel void kernel_add_row_c4_fuse_impl(
|
|
968
999
|
constant ggml_metal_kargs_bin & args,
|
|
969
|
-
device const
|
|
970
|
-
device const
|
|
971
|
-
device
|
|
1000
|
+
device const char * src0,
|
|
1001
|
+
device const char * src1,
|
|
1002
|
+
device char * dst,
|
|
972
1003
|
uint tpig[[thread_position_in_grid]]) {
|
|
1004
|
+
|
|
973
1005
|
const uint nb = args.ne00/4;
|
|
974
|
-
|
|
1006
|
+
const uint i = tpig % nb;
|
|
1007
|
+
|
|
1008
|
+
device const float4 * src0_row = (device const float4 *) (src0);
|
|
1009
|
+
device float4 * dst_row = (device float4 *) (dst);
|
|
1010
|
+
|
|
1011
|
+
device const float4 * src1_row[F];
|
|
1012
|
+
for (short j = 0; j < F; ++j) {
|
|
1013
|
+
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
1014
|
+
}
|
|
1015
|
+
|
|
1016
|
+
float4 res = src0_row[tpig];
|
|
1017
|
+
|
|
1018
|
+
#pragma unroll(F)
|
|
1019
|
+
for (short j = 0; j < F; ++j) {
|
|
1020
|
+
res += src1_row[j][i];
|
|
1021
|
+
}
|
|
1022
|
+
|
|
1023
|
+
dst_row[tpig] = res;
|
|
975
1024
|
}
|
|
976
1025
|
|
|
977
|
-
|
|
1026
|
+
typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
|
|
1027
|
+
|
|
1028
|
+
template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
|
|
1029
|
+
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
|
|
1030
|
+
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
|
|
1031
|
+
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
|
|
1032
|
+
template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
|
|
1033
|
+
template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
|
|
1034
|
+
template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
|
|
1035
|
+
template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
|
|
1036
|
+
|
|
1037
|
+
template <short F>
|
|
1038
|
+
kernel void kernel_sub_row_c4_fuse_impl(
|
|
978
1039
|
constant ggml_metal_kargs_bin & args,
|
|
979
|
-
device const
|
|
980
|
-
device const
|
|
981
|
-
device
|
|
1040
|
+
device const char * src0,
|
|
1041
|
+
device const char * src1,
|
|
1042
|
+
device char * dst,
|
|
982
1043
|
uint tpig[[thread_position_in_grid]]) {
|
|
1044
|
+
|
|
983
1045
|
const uint nb = args.ne00/4;
|
|
984
|
-
|
|
1046
|
+
const uint i = tpig % nb;
|
|
1047
|
+
|
|
1048
|
+
device const float4 * src0_row = (device const float4 *) (src0);
|
|
1049
|
+
device float4 * dst_row = (device float4 *) (dst);
|
|
1050
|
+
|
|
1051
|
+
device const float4 * src1_row[F];
|
|
1052
|
+
for (short j = 0; j < F; ++j) {
|
|
1053
|
+
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
1054
|
+
}
|
|
1055
|
+
|
|
1056
|
+
float4 res = src0_row[tpig];
|
|
1057
|
+
|
|
1058
|
+
#pragma unroll(F)
|
|
1059
|
+
for (short j = 0; j < F; ++j) {
|
|
1060
|
+
res -= src1_row[j][i];
|
|
1061
|
+
}
|
|
1062
|
+
|
|
1063
|
+
dst_row[tpig] = res;
|
|
985
1064
|
}
|
|
986
1065
|
|
|
987
|
-
|
|
1066
|
+
typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
|
|
1067
|
+
|
|
1068
|
+
template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
|
|
1069
|
+
|
|
1070
|
+
template <short F>
|
|
1071
|
+
kernel void kernel_mul_row_c4_fuse_impl(
|
|
988
1072
|
constant ggml_metal_kargs_bin & args,
|
|
989
|
-
device const
|
|
990
|
-
device const
|
|
991
|
-
device
|
|
1073
|
+
device const char * src0,
|
|
1074
|
+
device const char * src1,
|
|
1075
|
+
device char * dst,
|
|
992
1076
|
uint tpig[[thread_position_in_grid]]) {
|
|
1077
|
+
|
|
993
1078
|
const uint nb = args.ne00/4;
|
|
994
|
-
|
|
1079
|
+
const uint i = tpig % nb;
|
|
1080
|
+
|
|
1081
|
+
device const float4 * src0_row = (device const float4 *) (src0);
|
|
1082
|
+
device float4 * dst_row = (device float4 *) (dst);
|
|
1083
|
+
|
|
1084
|
+
device const float4 * src1_row[F];
|
|
1085
|
+
for (short j = 0; j < F; ++j) {
|
|
1086
|
+
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
1087
|
+
}
|
|
1088
|
+
|
|
1089
|
+
float4 res = src0_row[tpig];
|
|
1090
|
+
|
|
1091
|
+
#pragma unroll(F)
|
|
1092
|
+
for (short j = 0; j < F; ++j) {
|
|
1093
|
+
res *= src1_row[j][i];
|
|
1094
|
+
}
|
|
1095
|
+
|
|
1096
|
+
dst_row[tpig] = res;
|
|
995
1097
|
}
|
|
996
1098
|
|
|
997
|
-
|
|
1099
|
+
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
|
|
1100
|
+
|
|
1101
|
+
template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
|
|
1102
|
+
|
|
1103
|
+
template <short F>
|
|
1104
|
+
kernel void kernel_div_row_c4_fuse_impl(
|
|
998
1105
|
constant ggml_metal_kargs_bin & args,
|
|
999
|
-
device const
|
|
1000
|
-
device const
|
|
1001
|
-
device
|
|
1106
|
+
device const char * src0,
|
|
1107
|
+
device const char * src1,
|
|
1108
|
+
device char * dst,
|
|
1002
1109
|
uint tpig[[thread_position_in_grid]]) {
|
|
1110
|
+
|
|
1003
1111
|
const uint nb = args.ne00/4;
|
|
1004
|
-
|
|
1112
|
+
const uint i = tpig % nb;
|
|
1113
|
+
|
|
1114
|
+
device const float4 * src0_row = (device const float4 *) (src0);
|
|
1115
|
+
device float4 * dst_row = (device float4 *) (dst);
|
|
1116
|
+
|
|
1117
|
+
device const float4 * src1_row[F];
|
|
1118
|
+
for (short j = 0; j < F; ++j) {
|
|
1119
|
+
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
|
1120
|
+
}
|
|
1121
|
+
|
|
1122
|
+
float4 res = src0_row[tpig];
|
|
1123
|
+
|
|
1124
|
+
#pragma unroll(F)
|
|
1125
|
+
for (short j = 0; j < F; ++j) {
|
|
1126
|
+
res /= src1_row[j][i];
|
|
1127
|
+
}
|
|
1128
|
+
|
|
1129
|
+
dst_row[tpig] = res;
|
|
1005
1130
|
}
|
|
1006
1131
|
|
|
1132
|
+
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
|
|
1133
|
+
|
|
1134
|
+
template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
|
|
1135
|
+
|
|
1007
1136
|
kernel void kernel_scale(
|
|
1008
1137
|
device const float * src0,
|
|
1009
1138
|
device float * dst,
|
|
1010
1139
|
constant float & scale,
|
|
1140
|
+
constant float & bias,
|
|
1011
1141
|
uint tpig[[thread_position_in_grid]]) {
|
|
1012
|
-
dst[tpig] = src0[tpig] * scale;
|
|
1142
|
+
dst[tpig] = src0[tpig] * scale + bias;
|
|
1013
1143
|
}
|
|
1014
1144
|
|
|
1015
1145
|
kernel void kernel_scale_4(
|
|
1016
1146
|
device const float4 * src0,
|
|
1017
1147
|
device float4 * dst,
|
|
1018
1148
|
constant float & scale,
|
|
1149
|
+
constant float & bias,
|
|
1019
1150
|
uint tpig[[thread_position_in_grid]]) {
|
|
1020
|
-
dst[tpig] = src0[tpig] * scale;
|
|
1151
|
+
dst[tpig] = src0[tpig] * scale + bias;
|
|
1021
1152
|
}
|
|
1022
1153
|
|
|
1023
1154
|
kernel void kernel_clamp(
|
|
@@ -1191,6 +1322,159 @@ kernel void kernel_neg(
|
|
|
1191
1322
|
dst[tpig] = -src0[tpig];
|
|
1192
1323
|
}
|
|
1193
1324
|
|
|
1325
|
+
kernel void kernel_abs(
|
|
1326
|
+
device const float * src0,
|
|
1327
|
+
device float * dst,
|
|
1328
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1329
|
+
dst[tpig] = fabs(src0[tpig]);
|
|
1330
|
+
}
|
|
1331
|
+
|
|
1332
|
+
kernel void kernel_sgn(
|
|
1333
|
+
device const float * src0,
|
|
1334
|
+
device float * dst,
|
|
1335
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1336
|
+
device const float & x = src0[tpig];
|
|
1337
|
+
dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f);
|
|
1338
|
+
}
|
|
1339
|
+
|
|
1340
|
+
kernel void kernel_step(
|
|
1341
|
+
device const float * src0,
|
|
1342
|
+
device float * dst,
|
|
1343
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1344
|
+
dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f;
|
|
1345
|
+
}
|
|
1346
|
+
|
|
1347
|
+
kernel void kernel_hardswish(
|
|
1348
|
+
device const float * src0,
|
|
1349
|
+
device float * dst,
|
|
1350
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1351
|
+
device const float & x = src0[tpig];
|
|
1352
|
+
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1353
|
+
}
|
|
1354
|
+
|
|
1355
|
+
kernel void kernel_hardsigmoid(
|
|
1356
|
+
device const float * src0,
|
|
1357
|
+
device float * dst,
|
|
1358
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1359
|
+
device const float & x = src0[tpig];
|
|
1360
|
+
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
|
1361
|
+
}
|
|
1362
|
+
|
|
1363
|
+
kernel void kernel_exp(
|
|
1364
|
+
device const float * src0,
|
|
1365
|
+
device float * dst,
|
|
1366
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
1367
|
+
dst[tpig] = exp(src0[tpig]);
|
|
1368
|
+
}
|
|
1369
|
+
|
|
1370
|
+
kernel void kernel_reglu(
|
|
1371
|
+
device const char * src0,
|
|
1372
|
+
device const char * src1,
|
|
1373
|
+
device char * dst,
|
|
1374
|
+
constant ggml_metal_kargs_glu & args,
|
|
1375
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1376
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1377
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1378
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1379
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1380
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1381
|
+
|
|
1382
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1383
|
+
const float x0 = src0_row[i0];
|
|
1384
|
+
const float x1 = src1_row[i0];
|
|
1385
|
+
|
|
1386
|
+
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
|
1387
|
+
}
|
|
1388
|
+
}
|
|
1389
|
+
|
|
1390
|
+
kernel void kernel_geglu(
|
|
1391
|
+
device const char * src0,
|
|
1392
|
+
device const char * src1,
|
|
1393
|
+
device char * dst,
|
|
1394
|
+
constant ggml_metal_kargs_glu & args,
|
|
1395
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1396
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1397
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1398
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1399
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1400
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1401
|
+
|
|
1402
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1403
|
+
const float x0 = src0_row[i0];
|
|
1404
|
+
const float x1 = src1_row[i0];
|
|
1405
|
+
|
|
1406
|
+
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
|
1407
|
+
|
|
1408
|
+
dst_row[i0] = gelu*x1;
|
|
1409
|
+
}
|
|
1410
|
+
}
|
|
1411
|
+
|
|
1412
|
+
kernel void kernel_swiglu(
|
|
1413
|
+
device const char * src0,
|
|
1414
|
+
device const char * src1,
|
|
1415
|
+
device char * dst,
|
|
1416
|
+
constant ggml_metal_kargs_glu & args,
|
|
1417
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1418
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1419
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1420
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1421
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1422
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1423
|
+
|
|
1424
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1425
|
+
const float x0 = src0_row[i0];
|
|
1426
|
+
const float x1 = src1_row[i0];
|
|
1427
|
+
|
|
1428
|
+
const float silu = x0 / (1.0f + exp(-x0));
|
|
1429
|
+
|
|
1430
|
+
dst_row[i0] = silu*x1;
|
|
1431
|
+
}
|
|
1432
|
+
}
|
|
1433
|
+
|
|
1434
|
+
kernel void kernel_geglu_erf(
|
|
1435
|
+
device const char * src0,
|
|
1436
|
+
device const char * src1,
|
|
1437
|
+
device char * dst,
|
|
1438
|
+
constant ggml_metal_kargs_glu & args,
|
|
1439
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1440
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1441
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1442
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1443
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1444
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1445
|
+
|
|
1446
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1447
|
+
const float x0 = src0_row[i0];
|
|
1448
|
+
const float x1 = src1_row[i0];
|
|
1449
|
+
|
|
1450
|
+
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
|
1451
|
+
|
|
1452
|
+
dst_row[i0] = gelu_erf*x1;
|
|
1453
|
+
}
|
|
1454
|
+
}
|
|
1455
|
+
|
|
1456
|
+
kernel void kernel_geglu_quick(
|
|
1457
|
+
device const char * src0,
|
|
1458
|
+
device const char * src1,
|
|
1459
|
+
device char * dst,
|
|
1460
|
+
constant ggml_metal_kargs_glu & args,
|
|
1461
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1462
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1463
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1464
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1465
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1466
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1467
|
+
|
|
1468
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1469
|
+
const float x0 = src0_row[i0];
|
|
1470
|
+
const float x1 = src1_row[i0];
|
|
1471
|
+
|
|
1472
|
+
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
|
1473
|
+
|
|
1474
|
+
dst_row[i0] = gelu_quick*x1;
|
|
1475
|
+
}
|
|
1476
|
+
}
|
|
1477
|
+
|
|
1194
1478
|
template <bool norm>
|
|
1195
1479
|
kernel void kernel_sum_rows(
|
|
1196
1480
|
constant ggml_metal_kargs_sum_rows & args,
|
|
@@ -1253,24 +1537,28 @@ kernel void kernel_soft_max(
|
|
|
1253
1537
|
device char * dst,
|
|
1254
1538
|
constant ggml_metal_kargs_soft_max & args,
|
|
1255
1539
|
threadgroup float * buf [[threadgroup(0)]],
|
|
1256
|
-
|
|
1257
|
-
|
|
1540
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1541
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1258
1542
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
1259
1543
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1260
|
-
|
|
1261
|
-
const
|
|
1262
|
-
const
|
|
1263
|
-
const
|
|
1544
|
+
uint3 tptg[[threads_per_threadgroup]]) {
|
|
1545
|
+
const int32_t i03 = tgpig.z;
|
|
1546
|
+
const int32_t i02 = tgpig.y;
|
|
1547
|
+
const int32_t i01 = tgpig.x;
|
|
1548
|
+
|
|
1549
|
+
const int32_t i13 = i03%args.ne13;
|
|
1550
|
+
const int32_t i12 = i02%args.ne12;
|
|
1551
|
+
const int32_t i11 = i01;
|
|
1264
1552
|
|
|
1265
|
-
device const float * psrc0 =
|
|
1266
|
-
device const T * pmask = src1 != src0 ? (device const
|
|
1267
|
-
device float * pdst =
|
|
1553
|
+
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
1554
|
+
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
|
1555
|
+
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
1268
1556
|
|
|
1269
1557
|
float slope = 1.0f;
|
|
1270
1558
|
|
|
1271
1559
|
// ALiBi
|
|
1272
1560
|
if (args.max_bias > 0.0f) {
|
|
1273
|
-
const
|
|
1561
|
+
const int32_t h = i02;
|
|
1274
1562
|
|
|
1275
1563
|
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
1276
1564
|
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
@@ -1281,13 +1569,13 @@ kernel void kernel_soft_max(
|
|
|
1281
1569
|
// parallel max
|
|
1282
1570
|
float lmax = -INFINITY;
|
|
1283
1571
|
|
|
1284
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1572
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1285
1573
|
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
|
1286
1574
|
}
|
|
1287
1575
|
|
|
1288
1576
|
// find the max value in the block
|
|
1289
1577
|
float max_val = simd_max(lmax);
|
|
1290
|
-
if (
|
|
1578
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1291
1579
|
if (sgitg == 0) {
|
|
1292
1580
|
buf[tiisg] = -INFINITY;
|
|
1293
1581
|
}
|
|
@@ -1306,7 +1594,7 @@ kernel void kernel_soft_max(
|
|
|
1306
1594
|
|
|
1307
1595
|
// parallel sum
|
|
1308
1596
|
float lsum = 0.0f;
|
|
1309
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1597
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1310
1598
|
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
|
1311
1599
|
lsum += exp_psrc0;
|
|
1312
1600
|
pdst[i00] = exp_psrc0;
|
|
@@ -1318,7 +1606,7 @@ kernel void kernel_soft_max(
|
|
|
1318
1606
|
|
|
1319
1607
|
float sum = simd_sum(lsum);
|
|
1320
1608
|
|
|
1321
|
-
if (
|
|
1609
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1322
1610
|
if (sgitg == 0) {
|
|
1323
1611
|
buf[tiisg] = 0.0f;
|
|
1324
1612
|
}
|
|
@@ -1337,7 +1625,7 @@ kernel void kernel_soft_max(
|
|
|
1337
1625
|
|
|
1338
1626
|
const float inv_sum = 1.0f/sum;
|
|
1339
1627
|
|
|
1340
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1628
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1341
1629
|
pdst[i00] *= inv_sum;
|
|
1342
1630
|
}
|
|
1343
1631
|
}
|
|
@@ -1349,23 +1637,27 @@ kernel void kernel_soft_max_4(
|
|
|
1349
1637
|
device char * dst,
|
|
1350
1638
|
constant ggml_metal_kargs_soft_max & args,
|
|
1351
1639
|
threadgroup float * buf [[threadgroup(0)]],
|
|
1352
|
-
|
|
1353
|
-
|
|
1640
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1641
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1354
1642
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
1355
1643
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1356
|
-
|
|
1357
|
-
const
|
|
1358
|
-
const
|
|
1359
|
-
const
|
|
1644
|
+
uint3 tptg[[threads_per_threadgroup]]) {
|
|
1645
|
+
const int32_t i03 = tgpig.z;
|
|
1646
|
+
const int32_t i02 = tgpig.y;
|
|
1647
|
+
const int32_t i01 = tgpig.x;
|
|
1648
|
+
|
|
1649
|
+
const int32_t i13 = i03%args.ne13;
|
|
1650
|
+
const int32_t i12 = i02%args.ne12;
|
|
1651
|
+
const int32_t i11 = i01;
|
|
1360
1652
|
|
|
1361
|
-
device const float4 * psrc4 =
|
|
1362
|
-
device const T * pmask = src1 != src0 ? (device const
|
|
1363
|
-
device float4 * pdst4 =
|
|
1653
|
+
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
1654
|
+
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
|
1655
|
+
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
1364
1656
|
|
|
1365
1657
|
float slope = 1.0f;
|
|
1366
1658
|
|
|
1367
1659
|
if (args.max_bias > 0.0f) {
|
|
1368
|
-
const
|
|
1660
|
+
const int32_t h = i02;
|
|
1369
1661
|
|
|
1370
1662
|
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
1371
1663
|
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
@@ -1376,14 +1668,14 @@ kernel void kernel_soft_max_4(
|
|
|
1376
1668
|
// parallel max
|
|
1377
1669
|
float4 lmax4 = -INFINITY;
|
|
1378
1670
|
|
|
1379
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1671
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1380
1672
|
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
|
1381
1673
|
}
|
|
1382
1674
|
|
|
1383
1675
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
1384
1676
|
|
|
1385
1677
|
float max_val = simd_max(lmax);
|
|
1386
|
-
if (
|
|
1678
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1387
1679
|
if (sgitg == 0) {
|
|
1388
1680
|
buf[tiisg] = -INFINITY;
|
|
1389
1681
|
}
|
|
@@ -1402,7 +1694,7 @@ kernel void kernel_soft_max_4(
|
|
|
1402
1694
|
|
|
1403
1695
|
// parallel sum
|
|
1404
1696
|
float4 lsum4 = 0.0f;
|
|
1405
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1697
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1406
1698
|
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
|
1407
1699
|
lsum4 += exp_psrc4;
|
|
1408
1700
|
pdst4[i00] = exp_psrc4;
|
|
@@ -1416,7 +1708,7 @@ kernel void kernel_soft_max_4(
|
|
|
1416
1708
|
|
|
1417
1709
|
float sum = simd_sum(lsum);
|
|
1418
1710
|
|
|
1419
|
-
if (
|
|
1711
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1420
1712
|
if (sgitg == 0) {
|
|
1421
1713
|
buf[tiisg] = 0.0f;
|
|
1422
1714
|
}
|
|
@@ -1435,7 +1727,7 @@ kernel void kernel_soft_max_4(
|
|
|
1435
1727
|
|
|
1436
1728
|
const float inv_sum = 1.0f/sum;
|
|
1437
1729
|
|
|
1438
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1730
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1439
1731
|
pdst4[i00] *= inv_sum;
|
|
1440
1732
|
}
|
|
1441
1733
|
}
|
|
@@ -1521,7 +1813,7 @@ kernel void kernel_ssm_conv_f32(
|
|
|
1521
1813
|
x[0] = sumf;
|
|
1522
1814
|
}
|
|
1523
1815
|
|
|
1524
|
-
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
|
1816
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
|
|
1525
1817
|
kernel void kernel_ssm_scan_f32(
|
|
1526
1818
|
device const void * src0,
|
|
1527
1819
|
device const void * src1,
|
|
@@ -1529,47 +1821,222 @@ kernel void kernel_ssm_scan_f32(
|
|
|
1529
1821
|
device const void * src3,
|
|
1530
1822
|
device const void * src4,
|
|
1531
1823
|
device const void * src5,
|
|
1824
|
+
device const void * src6,
|
|
1532
1825
|
device float * dst,
|
|
1826
|
+
threadgroup float * shared [[threadgroup(0)]],
|
|
1533
1827
|
constant ggml_metal_kargs_ssm_scan & args,
|
|
1534
|
-
uint3
|
|
1535
|
-
uint3
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1828
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1829
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1830
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1831
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1832
|
+
ushort sgptg[[simdgroups_per_threadgroup]],
|
|
1833
|
+
uint3 tgpg[[threadgroups_per_grid]]) {
|
|
1834
|
+
|
|
1835
|
+
const int64_t i0 = tpitg.x;
|
|
1836
|
+
const int64_t i1 = 0;
|
|
1837
|
+
const int64_t ir = tgpig.x; // current head
|
|
1838
|
+
const int64_t i3 = tgpig.y; // current seq
|
|
1839
|
+
|
|
1840
|
+
const uint64_t nb00 = sizeof(float);
|
|
1841
|
+
const uint64_t nb10 = sizeof(float);
|
|
1842
|
+
const uint64_t nb20 = sizeof(float);
|
|
1539
1843
|
|
|
1540
1844
|
const int64_t nc = args.d_state;
|
|
1541
|
-
|
|
1845
|
+
const int64_t nr = args.d_inner;
|
|
1846
|
+
const int64_t nh = args.n_head;
|
|
1847
|
+
const int64_t ng = args.n_group;
|
|
1542
1848
|
const int64_t n_t = args.n_seq_tokens;
|
|
1543
|
-
|
|
1849
|
+
|
|
1850
|
+
const int64_t s_off = args.s_off;
|
|
1851
|
+
|
|
1852
|
+
device const int32_t * ids = (device const int32_t *) src6;
|
|
1853
|
+
|
|
1854
|
+
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
1855
|
+
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
1856
|
+
const int64_t i = i0 + i1*nc;
|
|
1857
|
+
float s0 = s0_buff[i];
|
|
1858
|
+
float s = s_buff[i];
|
|
1859
|
+
|
|
1860
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
|
|
1861
|
+
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
|
1862
|
+
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
|
1863
|
+
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
|
1864
|
+
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
|
1865
|
+
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
|
1544
1866
|
|
|
1545
1867
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
1546
|
-
device const float *
|
|
1547
|
-
device const float *
|
|
1548
|
-
device const float *
|
|
1549
|
-
device const float *
|
|
1550
|
-
device
|
|
1551
|
-
|
|
1552
|
-
|
|
1553
|
-
|
|
1868
|
+
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
|
|
1869
|
+
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
|
|
1870
|
+
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
|
|
1871
|
+
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
|
|
1872
|
+
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
|
|
1873
|
+
|
|
1874
|
+
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1875
|
+
const float x_dt = x[0] * dt_soft_plus;
|
|
1876
|
+
|
|
1877
|
+
const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
|
1878
|
+
s = state;
|
|
1879
|
+
|
|
1880
|
+
// Parallel sum: This relies on the fact that this kernel will be
|
|
1881
|
+
// dispatched with each threadgroup having (d_state, 1, 1) threads which
|
|
1882
|
+
// are subdivided into `sgptg` SIMD groups. The goal is to
|
|
1883
|
+
// compute y = sum({state[i] * C[i] for i in range(d_state)}), where thread i holds state[i].
|
|
1884
|
+
// To parallelize this effectively, we first use simd_sum over each SIMD
|
|
1885
|
+
// group to compute the sum of each SIMD group, then place the result in
|
|
1886
|
+
// the SIMD group's indexed bucket in the shared memory. We then sum
|
|
1887
|
+
// over the individual group sums to compute the final sum.
|
|
1554
1888
|
|
|
1555
|
-
|
|
1556
|
-
|
|
1889
|
+
// Computed for each thread
|
|
1890
|
+
float sumf = state * C[i0];
|
|
1891
|
+
|
|
1892
|
+
// Sum the threads in the simd group => simd sum
|
|
1893
|
+
sumf = simd_sum(sumf);
|
|
1894
|
+
|
|
1895
|
+
if (sgptg > 1) {
|
|
1896
|
+
|
|
1897
|
+
// Once per simd group, place the group sum into the shared buffer
|
|
1898
|
+
if (tiisg == 0) {
|
|
1899
|
+
shared[sgitg] = sumf;
|
|
1900
|
+
}
|
|
1901
|
+
|
|
1902
|
+
// Wait for all threads in the threadgroup to reach this point. This
|
|
1903
|
+
// ensures that all elements of the shared buffer are populated with the
|
|
1904
|
+
// sum of the individual simd groups.
|
|
1905
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1906
|
+
|
|
1907
|
+
// For simd group 0 at indices < num simd groups, extract the shared
|
|
1908
|
+
// simd sum
|
|
1909
|
+
sumf = 0.0f;
|
|
1910
|
+
if (sgitg == 0) {
|
|
1911
|
+
if (tiisg < sgptg) {
|
|
1912
|
+
sumf = shared[tiisg];
|
|
1913
|
+
}
|
|
1914
|
+
sumf = simd_sum(sumf);
|
|
1915
|
+
if (tiisg == 0) {
|
|
1916
|
+
y[0] = sumf;
|
|
1917
|
+
}
|
|
1918
|
+
}
|
|
1919
|
+
} else if (tiisg == 0) {
|
|
1920
|
+
y[0] = sumf;
|
|
1557
1921
|
}
|
|
1558
1922
|
|
|
1559
|
-
//
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
1923
|
+
// recurse
|
|
1924
|
+
s0 = s;
|
|
1925
|
+
}
|
|
1926
|
+
|
|
1927
|
+
// Assign the final state to the output buffer
|
|
1928
|
+
s_buff[i] = s;
|
|
1929
|
+
}
|
|
1563
1930
|
|
|
1564
|
-
|
|
1565
|
-
|
|
1566
|
-
|
|
1567
|
-
|
|
1568
|
-
|
|
1931
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
|
1932
|
+
kernel void kernel_ssm_scan_f32_group(
|
|
1933
|
+
device const void * src0,
|
|
1934
|
+
device const void * src1,
|
|
1935
|
+
device const void * src2,
|
|
1936
|
+
device const void * src3,
|
|
1937
|
+
device const void * src4,
|
|
1938
|
+
device const void * src5,
|
|
1939
|
+
device const void * src6,
|
|
1940
|
+
device float * dst,
|
|
1941
|
+
threadgroup float * shared [[threadgroup(0)]],
|
|
1942
|
+
constant ggml_metal_kargs_ssm_scan & args,
|
|
1943
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1944
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1945
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1946
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1947
|
+
ushort sgptg[[simdgroups_per_threadgroup]],
|
|
1948
|
+
uint3 tgpg[[threadgroups_per_grid]]) {
|
|
1949
|
+
|
|
1950
|
+
const int64_t i0 = tpitg.x;
|
|
1951
|
+
const int64_t i1 = tgpig.x;
|
|
1952
|
+
const int64_t ir = tgpig.y; // current head
|
|
1953
|
+
const int64_t i3 = tgpig.z; // current seq
|
|
1954
|
+
|
|
1955
|
+
const uint64_t nb00 = sizeof(float);
|
|
1956
|
+
const uint64_t nb10 = sizeof(float);
|
|
1957
|
+
const uint64_t nb20 = sizeof(float);
|
|
1958
|
+
|
|
1959
|
+
const int64_t nc = args.d_state;
|
|
1960
|
+
const int64_t nr = args.d_inner;
|
|
1961
|
+
const int64_t nh = args.n_head;
|
|
1962
|
+
const int64_t ng = args.n_group;
|
|
1963
|
+
const int64_t n_t = args.n_seq_tokens;
|
|
1964
|
+
|
|
1965
|
+
const int64_t s_off = args.s_off;
|
|
1966
|
+
|
|
1967
|
+
device const int32_t * ids = (device const int32_t *) src6;
|
|
1968
|
+
|
|
1969
|
+
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
1970
|
+
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
1971
|
+
const int64_t i = i0 + i1*nc;
|
|
1972
|
+
float s0 = s0_buff[i];
|
|
1973
|
+
float s = s_buff[i];
|
|
1974
|
+
|
|
1975
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
|
1976
|
+
device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
|
|
1977
|
+
device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
|
|
1978
|
+
device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43);
|
|
1979
|
+
device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53);
|
|
1980
|
+
device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
|
|
1981
|
+
|
|
1982
|
+
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
1983
|
+
device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
|
|
1984
|
+
device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
|
|
1985
|
+
device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
|
|
1986
|
+
device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
|
|
1987
|
+
device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
|
|
1988
|
+
|
|
1989
|
+
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1990
|
+
const float x_dt = x[0] * dt_soft_plus;
|
|
1991
|
+
const float dA = exp(dt_soft_plus * A[0]);
|
|
1992
|
+
|
|
1993
|
+
const float state = (s0 * dA) + (B[i0] * x_dt);
|
|
1994
|
+
s = state;
|
|
1995
|
+
|
|
1996
|
+
// Parallel sum: This relies on the fact that this kernel will be
|
|
1997
|
+
// dispatched with each threadgroup having (d_state, 1, 1) threads which
|
|
1998
|
+
// are subdivided into `sgptg` SIMD groups. The goal is to
|
|
1999
|
+
// compute y = sum({state[i] * C[i] for i in range(d_state)}), where thread i holds state[i].
|
|
2000
|
+
// To parallelize this effectively, we first use simd_sum over each SIMD
|
|
2001
|
+
// group to compute the sum of each SIMD group, then place the result in
|
|
2002
|
+
// the SIMD group's indexed bucket in the shared memory. We then sum
|
|
2003
|
+
// over the individual group sums to compute the final sum.
|
|
2004
|
+
|
|
2005
|
+
// Computed for each thread
|
|
2006
|
+
float sumf = state * C[i0];
|
|
2007
|
+
|
|
2008
|
+
// Sum the threads in the simd group => simd sum
|
|
2009
|
+
sumf = simd_sum(sumf);
|
|
2010
|
+
|
|
2011
|
+
// Once per simd group, place the group sum into the shared buffer
|
|
2012
|
+
if (tiisg == 0) {
|
|
2013
|
+
shared[sgitg] = sumf;
|
|
1569
2014
|
}
|
|
1570
2015
|
|
|
1571
|
-
|
|
2016
|
+
// Wait for all threads in the threadgroup to reach this point. This
|
|
2017
|
+
// ensures that all elements of the shared buffer are populated with the
|
|
2018
|
+
// sum of the individual simd groups.
|
|
2019
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
2020
|
+
|
|
2021
|
+
// For simd group 0 at indices < num simd groups, extract the shared
|
|
2022
|
+
// simd sum
|
|
2023
|
+
sumf = 0.0f;
|
|
2024
|
+
if (sgitg == 0) {
|
|
2025
|
+
if (tiisg < sgptg) {
|
|
2026
|
+
sumf = shared[tiisg];
|
|
2027
|
+
}
|
|
2028
|
+
sumf = simd_sum(sumf);
|
|
2029
|
+
if (tiisg == 0) {
|
|
2030
|
+
y[0] = sumf;
|
|
2031
|
+
}
|
|
2032
|
+
}
|
|
2033
|
+
|
|
2034
|
+
// recurse
|
|
2035
|
+
s0 = s;
|
|
1572
2036
|
}
|
|
2037
|
+
|
|
2038
|
+
// Assign the final state to the output buffer
|
|
2039
|
+
s_buff[i] = s;
|
|
1573
2040
|
}
|
|
1574
2041
|
|
|
1575
2042
|
kernel void kernel_rwkv_wkv6_f32(
|
|
@@ -1874,26 +2341,39 @@ kernel void kernel_norm(
|
|
|
1874
2341
|
}
|
|
1875
2342
|
}
|
|
1876
2343
|
|
|
1877
|
-
|
|
2344
|
+
// F == 1 : rms_norm (no fuse)
|
|
2345
|
+
// F == 2 : rms_norm + mul
|
|
2346
|
+
// F == 3 : rms_norm + mul + add
|
|
2347
|
+
template <short F>
|
|
2348
|
+
kernel void kernel_rms_norm_fuse_impl(
|
|
1878
2349
|
constant ggml_metal_kargs_rms_norm & args,
|
|
1879
2350
|
device const char * src0,
|
|
2351
|
+
device const char * src1_0,
|
|
2352
|
+
device const char * src1_1,
|
|
1880
2353
|
device char * dst,
|
|
1881
2354
|
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
ushort
|
|
1885
|
-
ushort
|
|
1886
|
-
|
|
2355
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2356
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
2357
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
2358
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
2359
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1887
2360
|
if (sgitg == 0) {
|
|
1888
2361
|
shmem_f32[tiisg] = 0.0f;
|
|
1889
2362
|
}
|
|
1890
2363
|
|
|
1891
|
-
|
|
2364
|
+
const int i01 = tgpig.x;
|
|
2365
|
+
const int i02 = tgpig.y;
|
|
2366
|
+
const int i03 = tgpig.z;
|
|
2367
|
+
|
|
2368
|
+
device const float4 * x = (device const float4 *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
|
2369
|
+
|
|
2370
|
+
device const float4 * f0 = (device const float4 *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
|
2371
|
+
device const float4 * f1 = (device const float4 *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
|
1892
2372
|
|
|
1893
2373
|
float sumf = 0.0f;
|
|
1894
2374
|
|
|
1895
2375
|
// parallel sum
|
|
1896
|
-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
|
2376
|
+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
|
|
1897
2377
|
sumf += dot(x[i00], x[i00]);
|
|
1898
2378
|
}
|
|
1899
2379
|
sumf = simd_sum(sumf);
|
|
@@ -1912,12 +2392,26 @@ kernel void kernel_rms_norm(
|
|
|
1912
2392
|
const float mean = sumf/args.ne00;
|
|
1913
2393
|
const float scale = 1.0f/sqrt(mean + args.eps);
|
|
1914
2394
|
|
|
1915
|
-
device float4 * y = (device float4 *) dst +
|
|
1916
|
-
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
|
1917
|
-
|
|
2395
|
+
device float4 * y = (device float4 *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
|
2396
|
+
for (int i00 = tpitg.x; i00 < args.ne00_4; i00 += ntg.x) {
|
|
2397
|
+
if (F == 1) {
|
|
2398
|
+
y[i00] = (x[i00]*scale);
|
|
2399
|
+
}
|
|
2400
|
+
if (F == 2) {
|
|
2401
|
+
y[i00] = (x[i00]*scale)*f0[i00];
|
|
2402
|
+
}
|
|
2403
|
+
if (F == 3) {
|
|
2404
|
+
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
|
|
2405
|
+
}
|
|
1918
2406
|
}
|
|
1919
2407
|
}
|
|
1920
2408
|
|
|
2409
|
+
// Function type shared by every instantiation of kernel_rms_norm_fuse_impl.
// The explicit instantiations below give each fusion variant its own
// host-visible kernel name via [[host_name]]:
//   F == 1 -> "kernel_rms_norm"          (rms_norm only)
//   F == 2 -> "kernel_rms_norm_mul"      (rms_norm + mul)
//   F == 3 -> "kernel_rms_norm_mul_add"  (rms_norm + mul + add)
typedef decltype(kernel_rms_norm_fuse_impl<1>) kernel_rms_norm_fuse_t;
|
|
2410
|
+
|
|
2411
|
+
template [[host_name("kernel_rms_norm")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<1>;
|
|
2412
|
+
template [[host_name("kernel_rms_norm_mul")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<2>;
|
|
2413
|
+
template [[host_name("kernel_rms_norm_mul_add")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<3>;
|
|
2414
|
+
|
|
1921
2415
|
kernel void kernel_l2_norm(
|
|
1922
2416
|
constant ggml_metal_kargs_l2_norm & args,
|
|
1923
2417
|
device const char * src0,
|
|
@@ -3709,7 +4203,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
3709
4203
|
// load the mask in shared memory
|
|
3710
4204
|
#pragma unroll(Q)
|
|
3711
4205
|
for (short j = 0; j < Q; ++j) {
|
|
3712
|
-
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
|
4206
|
+
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
3713
4207
|
|
|
3714
4208
|
const float m = pm[ic + tiisg];
|
|
3715
4209
|
|
|
@@ -4195,7 +4689,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
4195
4689
|
const bool has_mask = mask != q;
|
|
4196
4690
|
|
|
4197
4691
|
// pointer to the mask
|
|
4198
|
-
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
|
4692
|
+
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
4199
4693
|
|
|
4200
4694
|
float slope = 1.0f;
|
|
4201
4695
|
|