@novastera-oss/llamarn 0.2.7 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +8 -3
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +56 -22
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +1 -2
- package/cpp/llama.cpp/README.md +4 -5
- package/cpp/llama.cpp/build-xcframework.sh +1 -1
- package/cpp/llama.cpp/common/CMakeLists.txt +4 -5
- package/cpp/llama.cpp/common/arg.cpp +24 -0
- package/cpp/llama.cpp/common/chat.cpp +37 -20
- package/cpp/llama.cpp/common/chat.h +2 -0
- package/cpp/llama.cpp/common/common.cpp +3 -0
- package/cpp/llama.cpp/common/common.h +5 -0
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/convert_hf_to_gguf.py +860 -23
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +9 -0
- package/cpp/llama.cpp/ggml/CMakeLists.txt +8 -2
- package/cpp/llama.cpp/ggml/include/ggml-backend.h +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +206 -10
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +17 -1
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +36 -18
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +68 -5
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +16 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +37 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +10 -9
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +109 -108
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +1027 -1038
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +53 -52
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +56 -55
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +42 -41
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +24 -23
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +29 -28
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +30 -29
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +83 -82
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +20 -19
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +9 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +111 -103
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +1405 -240
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +25 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +56 -40
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +212 -34
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +35 -11
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +187 -54
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +71 -29
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +22 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +8 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +14 -12
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +5 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +15 -10
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +12 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/getrows.cu +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +269 -110
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +2 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/rope.cu +21 -27
- package/cpp/llama.cpp/ggml/src/ggml-cuda/scale.cu +8 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/softmax.cu +119 -58
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +192 -52
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +97 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/upscale.cu +92 -6
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +14 -5
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +125 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -2
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +51 -9
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +394 -80
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +616 -239
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +741 -571
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/glu.cl +337 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/set_rows.cl +95 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f16.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/softmax_f32.cl +24 -11
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +6 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +99 -166
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +94 -72
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +697 -1098
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +18 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +6 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +104 -62
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +132 -201
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +39 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +131 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.hpp +8 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +767 -292
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +58 -7
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +28 -23
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +14 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +38 -32
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +32 -27
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +44 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +128 -72
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +38 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +20 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +23 -3
- package/cpp/llama.cpp/ggml/src/ggml.c +449 -72
- package/cpp/llama.cpp/ggml/src/gguf.cpp +13 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +285 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +27 -0
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +137 -21
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +109 -7
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +8 -43
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +265 -3
- package/cpp/llama.cpp/src/llama-arch.h +36 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +596 -359
- package/cpp/llama.cpp/src/llama-batch.h +105 -70
- package/cpp/llama.cpp/src/llama-chat.cpp +26 -6
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +101 -107
- package/cpp/llama.cpp/src/llama-context.h +13 -13
- package/cpp/llama.cpp/src/llama-graph.cpp +286 -404
- package/cpp/llama.cpp/src/llama-graph.h +78 -79
- package/cpp/llama.cpp/src/llama-hparams.cpp +11 -1
- package/cpp/llama.cpp/src/llama-hparams.h +11 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +74 -66
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +312 -157
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +79 -46
- package/cpp/llama.cpp/src/llama-kv-cells.h +97 -21
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +73 -69
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +19 -22
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +88 -77
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +15 -20
- package/cpp/llama.cpp/src/llama-memory.cpp +17 -0
- package/cpp/llama.cpp/src/llama-memory.h +21 -22
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +5301 -2922
- package/cpp/llama.cpp/src/llama-model.h +40 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +88 -5
- package/cpp/llama.cpp/src/llama-vocab.cpp +37 -3
- package/cpp/llama.cpp/src/llama-vocab.h +42 -0
- package/cpp/rn-utils.h +3 -0
- package/ios/include/chat.h +2 -0
- package/ios/include/common.h +5 -0
- package/ios/include/llama.h +8 -43
- package/ios/libs/llama.xcframework/Info.plist +19 -19
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3891 -3744
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5059 -4863
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5030 -4834
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3889 -3742
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5095 -4900
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5066 -4871
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3919 -3773
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-backend.h +1 -1
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +206 -10
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +8 -43
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -1
- package/cpp/llama.cpp/ggml/include/ggml-kompute.h +0 -50
- package/cpp/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
- package/cpp/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
- package/cpp/llama.cpp/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
|
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
|
35
35
|
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
|
36
36
|
};
|
|
37
37
|
|
|
38
|
+
// Returns the index of the entry of the ascending table val[0..n-1] closest to x.
// Inputs at or beyond either end of the table clamp to index 0 / n-1. Otherwise a
// binary search narrows x to the interval [val[hi-1], val[hi]) and the nearer
// endpoint wins; on an exact tie the upper index is returned.
static inline int best_index_int8(int n, constant float * val, float x) {
    const int last = n - 1;

    if (x <= val[0]) {
        return 0;
    }
    if (x >= val[last]) {
        return last;
    }

    int lo = 0;
    int hi = last;
    // invariant: val[lo] <= x < val[hi]
    while (hi - lo > 1) {
        const int mid = (lo + hi) / 2;
        if (x < val[mid]) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    const float dist_below = x - val[hi - 1];
    const float dist_above = val[hi] - x;
    return dist_below < dist_above ? hi - 1 : hi;
}
|
|
48
|
+
|
|
38
49
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
39
50
|
template <typename type4x4>
|
|
40
51
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
@@ -97,6 +108,178 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
97
108
|
}
|
|
98
109
|
}
|
|
99
110
|
|
|
111
|
+
// Quantizes QK4_0 consecutive floats from src into one q4_0 block.
// The element with the largest magnitude (sign kept) fixes the scale
// d = vmax / -8, so that element lands on quant 0 after the +8 offset.
// Element i of the first half is stored in the low nibble of qs[i];
// element i of the second half is stored in the high nibble of qs[i].
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // largest |value| seen so far
    float vmax = 0.0f; // signed value that produced amax

    for (int i = 0; i < QK4_0; ++i) {
        const float a = fabs(src[i]);
        if (a > amax) {
            amax = a;
            vmax = src[i];
        }
    }

    const float d  = vmax / -8;
    const float id = d != 0.0f ? 1.0f/d : 0.0f; // all-zero block -> id = 0

    dst.d = d;

    const int nh = QK4_0/2;
    for (int i = 0; i < nh; ++i) {
        // scale first, then offset + round in a separate step
        const float s_lo = src[i]*id;
        const float s_hi = src[i + nh]*id;

        const uint8_t q_lo = MIN(15, (int8_t)(s_lo + 8.5f));
        const uint8_t q_hi = MIN(15, (int8_t)(s_hi + 8.5f));

        dst.qs[i] = q_lo | (q_hi << 4);
    }
}
|
|
140
|
+
|
|
141
|
+
// Quantizes QK4_1 consecutive floats from src into one q4_1 block.
// The block's [min, max] range is split into 15 equal steps:
// dst.d = step size, dst.m = minimum, and each value becomes round((v - min)/d),
// capped at 15. The first half goes into low nibbles and the second half into
// high nibbles of qs.
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
#pragma METAL fp math_mode(safe)
    float vmin =  FLT_MAX;
    float vmax = -FLT_MAX;

    for (int i = 0; i < QK4_1; ++i) {
        const float v = src[i];
        vmin = v < vmin ? v : vmin;
        vmax = v > vmax ? v : vmax;
    }

    const float d  = (vmax - vmin) / ((1 << 4) - 1);
    const float id = d != 0.0f ? 1.0f/d : 0.0f; // constant block -> id = 0

    dst.d = d;
    dst.m = vmin;

    const int nh = QK4_1/2;
    for (int i = 0; i < nh; ++i) {
        const float s_lo = (src[i]      - vmin)*id;
        const float s_hi = (src[i + nh] - vmin)*id;

        const uint8_t q_lo = MIN(15, (int8_t)(s_lo + 0.5f));
        const uint8_t q_hi = MIN(15, (int8_t)(s_hi + 0.5f));

        dst.qs[i] = q_lo | (q_hi << 4);
    }
}
|
|
169
|
+
|
|
170
|
+
// Quantizes QK5_0 consecutive floats from src into one q5_0 block.
// The signed largest-magnitude element fixes d = vmax / -16. Each value becomes
// a 5-bit quant in [0, 31] (offset +16, capped at 31). The low 4 bits are packed
// into qs as nibbles: first half low, second half high. Bit 4 of quant i goes to
// bit i of a 32-bit mask, whose four bytes are written to dst.qh.
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // largest |value| seen so far
    float vmax = 0.0f; // signed value that produced amax

    for (int i = 0; i < QK5_0; ++i) {
        const float a = fabs(src[i]);
        if (a > amax) {
            amax = a;
            vmax = src[i];
        }
    }

    const float d  = vmax / -16;
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst.d = d;

    const int nh = QK5_0/2;
    uint32_t high_bits = 0;

    for (int i = 0; i < nh; ++i) {
        const float s_lo = src[i]*id;
        const float s_hi = src[i + nh]*id;

        const uint8_t q_lo = MIN(31, (int8_t)(s_lo + 16.5f));
        const uint8_t q_hi = MIN(31, (int8_t)(s_hi + 16.5f));

        dst.qs[i] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);

        high_bits |= (uint32_t)((q_lo >> 4) & 1u) << i;
        high_bits |= (uint32_t)((q_hi >> 4) & 1u) << (i + nh);
    }

    // byte b receives bits [8b, 8b+8) of the mask (little-endian byte order)
    for (int b = 0; b < 4; ++b) {
        dst.qh[b] = (uint8_t)(high_bits >> (8*b));
    }
}
|
|
207
|
+
|
|
208
|
+
// Quantizes QK5_1 consecutive floats from src into one q5_1 block.
// The block's [min, max] range is split into 31 steps: dst.d = step size,
// dst.m = minimum. Each value becomes the 5-bit quant (uint8_t)((v - min)/d + 0.5).
// The low 4 bits are packed as nibbles into qs; bit 4 of quant i goes to bit i of
// a 32-bit mask, whose four bytes are written to dst.qh.
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
#pragma METAL fp math_mode(safe)
    float vmin = src[0];
    float vmax = src[0];

    for (int i = 1; i < QK5_1; ++i) {
        const float v = src[i];
        if (v < vmin) {
            vmin = v;
        }
        if (v > vmax) {
            vmax = v;
        }
    }

    const float d  = (vmax - vmin) / 31;
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst.d = d;
    dst.m = vmin;

    const int nh = QK5_1/2;
    uint32_t high_bits = 0;

    for (int i = 0; i < nh; ++i) {
        const float s_lo = (src[i]      - vmin)*id;
        const float s_hi = (src[i + nh] - vmin)*id;

        const uint8_t q_lo = (uint8_t)(s_lo + 0.5f);
        const uint8_t q_hi = (uint8_t)(s_hi + 0.5f);

        dst.qs[i] = (q_lo & 0xf) | ((q_hi & 0xf) << 4);

        high_bits |= (uint32_t)((q_lo >> 4) & 1u) << i;
        high_bits |= (uint32_t)((q_hi >> 4) & 1u) << (i + nh);
    }

    // byte b receives bits [8b, 8b+8) of the mask (little-endian byte order)
    for (int b = 0; b < 4; ++b) {
        dst.qh[b] = (uint8_t)(high_bits >> (8*b));
    }
}
|
|
244
|
+
|
|
245
|
+
// Quantizes QK4_NL consecutive floats from src into one iq4_nl block.
// An initial scale d = vmax / kvalues_iq4nl_f[0] is derived from the signed
// largest-magnitude element. Each scaled value is mapped to the nearest entry of
// the 16-entry kvalues_iq4nl_f table with best_index_int8, and the indices are
// packed as nibbles into qs. The stored scale is then refit as
// sum(w*v*x) / sum(w*v*v) with weights w = x*x, which minimizes
// sum(w*(x - d*v)^2). If that denominator is not positive, the initial d is used.
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // largest |value| seen so far
    float vmax = 0.0f; // signed value that produced amax

    for (int i = 0; i < QK4_NL; ++i) {
        const float a = fabs(src[i]);
        if (a > amax) {
            amax = a;
            vmax = src[i];
        }
    }

    const float d  = vmax / kvalues_iq4nl_f[0];
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    const int nh = QK4_NL/2;
    float sumqx = 0;
    float sumq2 = 0;

    for (int i = 0; i < nh; ++i) {
        const float s_lo = src[i];
        const float s_hi = src[i + nh];

        const float x_lo = s_lo*id;
        const float x_hi = s_hi*id;

        const uint8_t q_lo = best_index_int8(16, kvalues_iq4nl_f, x_lo);
        const uint8_t q_hi = best_index_int8(16, kvalues_iq4nl_f, x_hi);

        dst.qs[i] = q_lo | (q_hi << 4);

        const float v_lo = kvalues_iq4nl_f[q_lo];
        const float v_hi = kvalues_iq4nl_f[q_hi];
        const float w_lo = s_lo*s_lo;
        const float w_hi = s_hi*s_hi;

        sumqx += w_lo*v_lo*s_lo + w_hi*v_hi*s_hi;
        sumq2 += w_lo*v_lo*v_lo + w_hi*v_hi*v_hi;
    }

    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
}
|
|
282
|
+
|
|
100
283
|
template <typename type4x4>
|
|
101
284
|
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
|
102
285
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
@@ -279,6 +462,27 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
|
|
279
462
|
}
|
|
280
463
|
}
|
|
281
464
|
|
|
465
|
+
// Quantizes QK8_0 consecutive floats from src into one q8_0 block.
// The scale d = max|v| / 127 maps the block's largest magnitude onto +/-127;
// each value is stored as round(v / d) in qs. An all-zero block stores d = 0
// and all-zero quants.
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
    float amax = 0.0f; // absolute max over the block

    for (int i = 0; i < QK8_0; ++i) {
        amax = MAX(amax, fabs(src[i]));
    }

    const float d  = amax / ((1 << 7) - 1);
    const float id = d != 0.0f ? 1.0f/d : 0.0f;

    dst.d = d;

    for (int i = 0; i < QK8_0; ++i) {
        const float scaled = src[i]*id;
        dst.qs[i] = round(scaled);
    }
}
|
|
485
|
+
|
|
282
486
|
template <typename type4x4>
|
|
283
487
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
|
284
488
|
const float d = xb->d;
|
|
@@ -810,16 +1014,18 @@ kernel void kernel_scale(
|
|
|
810
1014
|
device const float * src0,
|
|
811
1015
|
device float * dst,
|
|
812
1016
|
constant float & scale,
|
|
1017
|
+
constant float & bias,
|
|
813
1018
|
uint tpig[[thread_position_in_grid]]) {
|
|
814
|
-
dst[tpig] = src0[tpig] * scale;
|
|
1019
|
+
dst[tpig] = src0[tpig] * scale + bias;
|
|
815
1020
|
}
|
|
816
1021
|
|
|
817
1022
|
kernel void kernel_scale_4(
|
|
818
1023
|
device const float4 * src0,
|
|
819
1024
|
device float4 * dst,
|
|
820
1025
|
constant float & scale,
|
|
1026
|
+
constant float & bias,
|
|
821
1027
|
uint tpig[[thread_position_in_grid]]) {
|
|
822
|
-
dst[tpig] = src0[tpig] * scale;
|
|
1028
|
+
dst[tpig] = src0[tpig] * scale + bias;
|
|
823
1029
|
}
|
|
824
1030
|
|
|
825
1031
|
kernel void kernel_clamp(
|
|
@@ -993,6 +1199,114 @@ kernel void kernel_neg(
|
|
|
993
1199
|
dst[tpig] = -src0[tpig];
|
|
994
1200
|
}
|
|
995
1201
|
|
|
1202
|
+
kernel void kernel_reglu(
|
|
1203
|
+
device const char * src0,
|
|
1204
|
+
device const char * src1,
|
|
1205
|
+
device char * dst,
|
|
1206
|
+
constant ggml_metal_kargs_glu & args,
|
|
1207
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1208
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1209
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1210
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1211
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1212
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1213
|
+
|
|
1214
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1215
|
+
const float x0 = src0_row[i0];
|
|
1216
|
+
const float x1 = src1_row[i0];
|
|
1217
|
+
|
|
1218
|
+
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
|
1219
|
+
}
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
kernel void kernel_geglu(
|
|
1223
|
+
device const char * src0,
|
|
1224
|
+
device const char * src1,
|
|
1225
|
+
device char * dst,
|
|
1226
|
+
constant ggml_metal_kargs_glu & args,
|
|
1227
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1228
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1229
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1230
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1231
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1232
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1233
|
+
|
|
1234
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1235
|
+
const float x0 = src0_row[i0];
|
|
1236
|
+
const float x1 = src1_row[i0];
|
|
1237
|
+
|
|
1238
|
+
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
|
1239
|
+
|
|
1240
|
+
dst_row[i0] = gelu*x1;
|
|
1241
|
+
}
|
|
1242
|
+
}
|
|
1243
|
+
|
|
1244
|
+
kernel void kernel_swiglu(
|
|
1245
|
+
device const char * src0,
|
|
1246
|
+
device const char * src1,
|
|
1247
|
+
device char * dst,
|
|
1248
|
+
constant ggml_metal_kargs_glu & args,
|
|
1249
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1250
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1251
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1252
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1253
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1254
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1255
|
+
|
|
1256
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1257
|
+
const float x0 = src0_row[i0];
|
|
1258
|
+
const float x1 = src1_row[i0];
|
|
1259
|
+
|
|
1260
|
+
const float silu = x0 / (1.0f + exp(-x0));
|
|
1261
|
+
|
|
1262
|
+
dst_row[i0] = silu*x1;
|
|
1263
|
+
}
|
|
1264
|
+
}
|
|
1265
|
+
|
|
1266
|
+
kernel void kernel_geglu_erf(
|
|
1267
|
+
device const char * src0,
|
|
1268
|
+
device const char * src1,
|
|
1269
|
+
device char * dst,
|
|
1270
|
+
constant ggml_metal_kargs_glu & args,
|
|
1271
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1272
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1273
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1274
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1275
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1276
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1277
|
+
|
|
1278
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1279
|
+
const float x0 = src0_row[i0];
|
|
1280
|
+
const float x1 = src1_row[i0];
|
|
1281
|
+
|
|
1282
|
+
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
|
1283
|
+
|
|
1284
|
+
dst_row[i0] = gelu_erf*x1;
|
|
1285
|
+
}
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
kernel void kernel_geglu_quick(
|
|
1289
|
+
device const char * src0,
|
|
1290
|
+
device const char * src1,
|
|
1291
|
+
device char * dst,
|
|
1292
|
+
constant ggml_metal_kargs_glu & args,
|
|
1293
|
+
uint tgpig[[threadgroup_position_in_grid]],
|
|
1294
|
+
uint tpitg[[thread_position_in_threadgroup]],
|
|
1295
|
+
uint ntg[[threads_per_threadgroup]]) {
|
|
1296
|
+
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
|
1297
|
+
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
|
1298
|
+
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
|
1299
|
+
|
|
1300
|
+
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
|
1301
|
+
const float x0 = src0_row[i0];
|
|
1302
|
+
const float x1 = src1_row[i0];
|
|
1303
|
+
|
|
1304
|
+
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
|
1305
|
+
|
|
1306
|
+
dst_row[i0] = gelu_quick*x1;
|
|
1307
|
+
}
|
|
1308
|
+
}
|
|
1309
|
+
|
|
996
1310
|
template <bool norm>
|
|
997
1311
|
kernel void kernel_sum_rows(
|
|
998
1312
|
constant ggml_metal_kargs_sum_rows & args,
|
|
@@ -1055,24 +1369,28 @@ kernel void kernel_soft_max(
|
|
|
1055
1369
|
device char * dst,
|
|
1056
1370
|
constant ggml_metal_kargs_soft_max & args,
|
|
1057
1371
|
threadgroup float * buf [[threadgroup(0)]],
|
|
1058
|
-
|
|
1059
|
-
|
|
1372
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1373
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1060
1374
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
1061
1375
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1062
|
-
|
|
1063
|
-
const
|
|
1064
|
-
const
|
|
1065
|
-
const
|
|
1376
|
+
uint3 tptg[[threads_per_threadgroup]]) {
|
|
1377
|
+
const int32_t i03 = tgpig.z;
|
|
1378
|
+
const int32_t i02 = tgpig.y;
|
|
1379
|
+
const int32_t i01 = tgpig.x;
|
|
1380
|
+
|
|
1381
|
+
const int32_t i13 = i03%args.ne13;
|
|
1382
|
+
const int32_t i12 = i02%args.ne12;
|
|
1383
|
+
const int32_t i11 = i01;
|
|
1066
1384
|
|
|
1067
|
-
device const float * psrc0 =
|
|
1068
|
-
device const T * pmask = src1 != src0 ? (device const
|
|
1069
|
-
device float * pdst =
|
|
1385
|
+
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
1386
|
+
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
|
1387
|
+
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
1070
1388
|
|
|
1071
1389
|
float slope = 1.0f;
|
|
1072
1390
|
|
|
1073
1391
|
// ALiBi
|
|
1074
1392
|
if (args.max_bias > 0.0f) {
|
|
1075
|
-
const
|
|
1393
|
+
const int32_t h = i02;
|
|
1076
1394
|
|
|
1077
1395
|
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
1078
1396
|
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
@@ -1083,13 +1401,13 @@ kernel void kernel_soft_max(
|
|
|
1083
1401
|
// parallel max
|
|
1084
1402
|
float lmax = -INFINITY;
|
|
1085
1403
|
|
|
1086
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1404
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1087
1405
|
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
|
1088
1406
|
}
|
|
1089
1407
|
|
|
1090
1408
|
// find the max value in the block
|
|
1091
1409
|
float max_val = simd_max(lmax);
|
|
1092
|
-
if (
|
|
1410
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1093
1411
|
if (sgitg == 0) {
|
|
1094
1412
|
buf[tiisg] = -INFINITY;
|
|
1095
1413
|
}
|
|
@@ -1108,7 +1426,7 @@ kernel void kernel_soft_max(
|
|
|
1108
1426
|
|
|
1109
1427
|
// parallel sum
|
|
1110
1428
|
float lsum = 0.0f;
|
|
1111
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1429
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1112
1430
|
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
|
1113
1431
|
lsum += exp_psrc0;
|
|
1114
1432
|
pdst[i00] = exp_psrc0;
|
|
@@ -1120,7 +1438,7 @@ kernel void kernel_soft_max(
|
|
|
1120
1438
|
|
|
1121
1439
|
float sum = simd_sum(lsum);
|
|
1122
1440
|
|
|
1123
|
-
if (
|
|
1441
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1124
1442
|
if (sgitg == 0) {
|
|
1125
1443
|
buf[tiisg] = 0.0f;
|
|
1126
1444
|
}
|
|
@@ -1139,7 +1457,7 @@ kernel void kernel_soft_max(
|
|
|
1139
1457
|
|
|
1140
1458
|
const float inv_sum = 1.0f/sum;
|
|
1141
1459
|
|
|
1142
|
-
for (int i00 = tpitg; i00 < args.ne00; i00 +=
|
|
1460
|
+
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
1143
1461
|
pdst[i00] *= inv_sum;
|
|
1144
1462
|
}
|
|
1145
1463
|
}
|
|
@@ -1151,23 +1469,27 @@ kernel void kernel_soft_max_4(
|
|
|
1151
1469
|
device char * dst,
|
|
1152
1470
|
constant ggml_metal_kargs_soft_max & args,
|
|
1153
1471
|
threadgroup float * buf [[threadgroup(0)]],
|
|
1154
|
-
|
|
1155
|
-
|
|
1472
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1473
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1156
1474
|
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
1157
1475
|
uint tiisg[[thread_index_in_simdgroup]],
|
|
1158
|
-
|
|
1159
|
-
const
|
|
1160
|
-
const
|
|
1161
|
-
const
|
|
1476
|
+
uint3 tptg[[threads_per_threadgroup]]) {
|
|
1477
|
+
const int32_t i03 = tgpig.z;
|
|
1478
|
+
const int32_t i02 = tgpig.y;
|
|
1479
|
+
const int32_t i01 = tgpig.x;
|
|
1162
1480
|
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
|
|
1481
|
+
const int32_t i13 = i03%args.ne13;
|
|
1482
|
+
const int32_t i12 = i02%args.ne12;
|
|
1483
|
+
const int32_t i11 = i01;
|
|
1484
|
+
|
|
1485
|
+
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
1486
|
+
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
|
1487
|
+
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
1166
1488
|
|
|
1167
1489
|
float slope = 1.0f;
|
|
1168
1490
|
|
|
1169
1491
|
if (args.max_bias > 0.0f) {
|
|
1170
|
-
const
|
|
1492
|
+
const int32_t h = i02;
|
|
1171
1493
|
|
|
1172
1494
|
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
|
1173
1495
|
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
|
@@ -1178,14 +1500,14 @@ kernel void kernel_soft_max_4(
|
|
|
1178
1500
|
// parallel max
|
|
1179
1501
|
float4 lmax4 = -INFINITY;
|
|
1180
1502
|
|
|
1181
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1503
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1182
1504
|
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
|
1183
1505
|
}
|
|
1184
1506
|
|
|
1185
1507
|
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
1186
1508
|
|
|
1187
1509
|
float max_val = simd_max(lmax);
|
|
1188
|
-
if (
|
|
1510
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1189
1511
|
if (sgitg == 0) {
|
|
1190
1512
|
buf[tiisg] = -INFINITY;
|
|
1191
1513
|
}
|
|
@@ -1204,7 +1526,7 @@ kernel void kernel_soft_max_4(
|
|
|
1204
1526
|
|
|
1205
1527
|
// parallel sum
|
|
1206
1528
|
float4 lsum4 = 0.0f;
|
|
1207
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1529
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1208
1530
|
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
|
1209
1531
|
lsum4 += exp_psrc4;
|
|
1210
1532
|
pdst4[i00] = exp_psrc4;
|
|
@@ -1218,7 +1540,7 @@ kernel void kernel_soft_max_4(
|
|
|
1218
1540
|
|
|
1219
1541
|
float sum = simd_sum(lsum);
|
|
1220
1542
|
|
|
1221
|
-
if (
|
|
1543
|
+
if (tptg.x > N_SIMDWIDTH) {
|
|
1222
1544
|
if (sgitg == 0) {
|
|
1223
1545
|
buf[tiisg] = 0.0f;
|
|
1224
1546
|
}
|
|
@@ -1237,7 +1559,7 @@ kernel void kernel_soft_max_4(
|
|
|
1237
1559
|
|
|
1238
1560
|
const float inv_sum = 1.0f/sum;
|
|
1239
1561
|
|
|
1240
|
-
for (int i00 = tpitg; i00 < args.ne00/4; i00 +=
|
|
1562
|
+
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
|
1241
1563
|
pdst4[i00] *= inv_sum;
|
|
1242
1564
|
}
|
|
1243
1565
|
}
|
|
@@ -1323,7 +1645,7 @@ kernel void kernel_ssm_conv_f32(
|
|
|
1323
1645
|
x[0] = sumf;
|
|
1324
1646
|
}
|
|
1325
1647
|
|
|
1326
|
-
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
|
1648
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
|
|
1327
1649
|
kernel void kernel_ssm_scan_f32(
|
|
1328
1650
|
device const void * src0,
|
|
1329
1651
|
device const void * src1,
|
|
@@ -1331,46 +1653,119 @@ kernel void kernel_ssm_scan_f32(
|
|
|
1331
1653
|
device const void * src3,
|
|
1332
1654
|
device const void * src4,
|
|
1333
1655
|
device const void * src5,
|
|
1656
|
+
device const void * src6,
|
|
1334
1657
|
device float * dst,
|
|
1335
1658
|
constant ggml_metal_kargs_ssm_scan & args,
|
|
1336
1659
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1337
1660
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1338
1661
|
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1339
|
-
const int64_t
|
|
1340
|
-
const int64_t
|
|
1662
|
+
const int64_t i1 = 0;
|
|
1663
|
+
const int64_t ir = tgpig.x; // current head
|
|
1664
|
+
const int64_t i3 = tgpig.y; // current seq
|
|
1665
|
+
|
|
1666
|
+
const uint64_t nb00 = sizeof(float);
|
|
1667
|
+
const uint64_t nb10 = sizeof(float);
|
|
1668
|
+
const uint64_t nb20 = sizeof(float);
|
|
1341
1669
|
|
|
1342
1670
|
const int64_t nc = args.d_state;
|
|
1343
|
-
|
|
1671
|
+
const int64_t nr = args.d_inner;
|
|
1672
|
+
const int64_t nh = args.n_head;
|
|
1673
|
+
const int64_t ng = args.n_group;
|
|
1344
1674
|
const int64_t n_t = args.n_seq_tokens;
|
|
1345
|
-
|
|
1675
|
+
|
|
1676
|
+
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
|
1677
|
+
|
|
1678
|
+
device const int32_t * ids = (device const int32_t *) src6;
|
|
1679
|
+
|
|
1680
|
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
1681
|
+
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
1346
1682
|
|
|
1347
1683
|
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
1348
|
-
device const float *
|
|
1349
|
-
device const float *
|
|
1350
|
-
device const float *
|
|
1351
|
-
device const float *
|
|
1352
|
-
device const float *
|
|
1353
|
-
device
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
if (i2 > 0) {
|
|
1358
|
-
s0 = s;
|
|
1359
|
-
}
|
|
1360
|
-
|
|
1361
|
-
// i1 == 0
|
|
1362
|
-
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1363
|
-
float x_dt = x[0] * dt_soft_plus;
|
|
1684
|
+
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
|
1685
|
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
|
1686
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
|
|
1687
|
+
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
|
1688
|
+
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
|
1689
|
+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
|
1690
|
+
|
|
1691
|
+
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1692
|
+
const float x_dt = x[0] * dt_soft_plus;
|
|
1364
1693
|
float sumf = 0.0f;
|
|
1365
1694
|
|
|
1366
1695
|
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
1367
|
-
int64_t i = i0;
|
|
1368
|
-
float state = (s0[i] * exp(dt_soft_plus * A[
|
|
1696
|
+
const int64_t i = i0 + i1*nc;
|
|
1697
|
+
const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
|
|
1369
1698
|
sumf += state * C[i0];
|
|
1370
1699
|
s[i] = state;
|
|
1371
1700
|
}
|
|
1372
1701
|
|
|
1373
1702
|
y[0] = sumf;
|
|
1703
|
+
|
|
1704
|
+
// recurse
|
|
1705
|
+
s0 = s;
|
|
1706
|
+
}
|
|
1707
|
+
}
|
|
1708
|
+
|
|
1709
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
|
1710
|
+
// TODO: optimize (e.g. by parallelizing over d_state)
|
|
1711
|
+
kernel void kernel_ssm_scan_f32_group(
|
|
1712
|
+
device const void * src0,
|
|
1713
|
+
device const void * src1,
|
|
1714
|
+
device const void * src2,
|
|
1715
|
+
device const void * src3,
|
|
1716
|
+
device const void * src4,
|
|
1717
|
+
device const void * src5,
|
|
1718
|
+
device const void * src6,
|
|
1719
|
+
device float * dst,
|
|
1720
|
+
constant ggml_metal_kargs_ssm_scan & args,
|
|
1721
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1722
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1723
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1724
|
+
const int64_t i1 = tgpig.x;
|
|
1725
|
+
const int64_t ir = tgpig.y; // current head
|
|
1726
|
+
const int64_t i3 = tgpig.z; // current seq
|
|
1727
|
+
|
|
1728
|
+
const uint64_t nb00 = sizeof(float);
|
|
1729
|
+
const uint64_t nb10 = sizeof(float);
|
|
1730
|
+
const uint64_t nb20 = sizeof(float);
|
|
1731
|
+
|
|
1732
|
+
const int64_t nc = args.d_state;
|
|
1733
|
+
const int64_t nr = args.d_inner;
|
|
1734
|
+
const int64_t nh = args.n_head;
|
|
1735
|
+
const int64_t ng = args.n_group;
|
|
1736
|
+
const int64_t n_t = args.n_seq_tokens;
|
|
1737
|
+
|
|
1738
|
+
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
|
|
1739
|
+
|
|
1740
|
+
device const int32_t * ids = (device const int32_t *) src6;
|
|
1741
|
+
|
|
1742
|
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
|
1743
|
+
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
|
1744
|
+
|
|
1745
|
+
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
1746
|
+
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
|
|
1747
|
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
|
|
1748
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
|
|
1749
|
+
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
|
|
1750
|
+
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
|
|
1751
|
+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
|
|
1752
|
+
|
|
1753
|
+
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
1754
|
+
const float x_dt = x[0] * dt_soft_plus;
|
|
1755
|
+
const float dA = exp(dt_soft_plus * A[0]);
|
|
1756
|
+
float sumf = 0.0f;
|
|
1757
|
+
|
|
1758
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
1759
|
+
const int64_t i = i0 + i1*nc;
|
|
1760
|
+
const float state = (s0[i] * dA) + (B[i0] * x_dt);
|
|
1761
|
+
sumf += state * C[i0];
|
|
1762
|
+
s[i] = state;
|
|
1763
|
+
}
|
|
1764
|
+
|
|
1765
|
+
y[0] = sumf;
|
|
1766
|
+
|
|
1767
|
+
// recurse
|
|
1768
|
+
s0 = s;
|
|
1374
1769
|
}
|
|
1375
1770
|
}
|
|
1376
1771
|
|
|
@@ -2532,6 +2927,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
|
|
|
2532
2927
|
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
|
2533
2928
|
#endif
|
|
2534
2929
|
|
|
2930
|
+
template<typename T04, typename T14, typename args_t>
|
|
2931
|
+
void kernel_mul_mv_c4_impl(
|
|
2932
|
+
args_t args,
|
|
2933
|
+
device const char * src0,
|
|
2934
|
+
device const char * src1,
|
|
2935
|
+
device char * dst,
|
|
2936
|
+
uint3 tgpig,
|
|
2937
|
+
ushort tiisg) {
|
|
2938
|
+
const int r0 = tgpig.x*32 + tiisg;
|
|
2939
|
+
const int rb = tgpig.y*N_MV_T_T;
|
|
2940
|
+
const int im = tgpig.z;
|
|
2941
|
+
|
|
2942
|
+
if (r0 >= args.ne01) {
|
|
2943
|
+
return;
|
|
2944
|
+
}
|
|
2945
|
+
|
|
2946
|
+
const uint i12 = im%args.ne12;
|
|
2947
|
+
const uint i13 = im/args.ne12;
|
|
2948
|
+
|
|
2949
|
+
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
|
2950
|
+
|
|
2951
|
+
device const T04 * x = (device const T04 *) (src0 + offset0);
|
|
2952
|
+
|
|
2953
|
+
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
|
2954
|
+
|
|
2955
|
+
for (int row = 0; row < N_MV_T_T; ++row) {
|
|
2956
|
+
int r1 = rb + row;
|
|
2957
|
+
if (r1 >= args.ne11) {
|
|
2958
|
+
break;
|
|
2959
|
+
}
|
|
2960
|
+
|
|
2961
|
+
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
2962
|
+
|
|
2963
|
+
device const T14 * y = (device const T14 *) (src1 + offset1);
|
|
2964
|
+
|
|
2965
|
+
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
|
|
2966
|
+
}
|
|
2967
|
+
}
|
|
2968
|
+
|
|
2969
|
+
template<typename T04, typename T14>
|
|
2970
|
+
kernel void kernel_mul_mv_c4(
|
|
2971
|
+
constant ggml_metal_kargs_mul_mv & args,
|
|
2972
|
+
device const char * src0,
|
|
2973
|
+
device const char * src1,
|
|
2974
|
+
device char * dst,
|
|
2975
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2976
|
+
ushort tiisg[[thread_index_in_simdgroup]]) {
|
|
2977
|
+
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
|
|
2978
|
+
args,
|
|
2979
|
+
src0,
|
|
2980
|
+
src1,
|
|
2981
|
+
dst,
|
|
2982
|
+
tgpig,
|
|
2983
|
+
tiisg);
|
|
2984
|
+
}
|
|
2985
|
+
|
|
2986
|
+
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
|
|
2987
|
+
|
|
2988
|
+
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
|
|
2989
|
+
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
|
|
2990
|
+
#if defined(GGML_METAL_USE_BF16)
|
|
2991
|
+
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
|
|
2992
|
+
#endif
|
|
2993
|
+
|
|
2535
2994
|
template<typename T, typename T4>
|
|
2536
2995
|
kernel void kernel_mul_mv_1row(
|
|
2537
2996
|
constant ggml_metal_kargs_mul_mv & args,
|
|
@@ -3447,7 +3906,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
3447
3906
|
// load the mask in shared memory
|
|
3448
3907
|
#pragma unroll(Q)
|
|
3449
3908
|
for (short j = 0; j < Q; ++j) {
|
|
3450
|
-
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
|
3909
|
+
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
3451
3910
|
|
|
3452
3911
|
const float m = pm[ic + tiisg];
|
|
3453
3912
|
|
|
@@ -3933,7 +4392,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
3933
4392
|
const bool has_mask = mask != q;
|
|
3934
4393
|
|
|
3935
4394
|
// pointer to the mask
|
|
3936
|
-
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
|
4395
|
+
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
|
|
3937
4396
|
|
|
3938
4397
|
float slope = 1.0f;
|
|
3939
4398
|
|
|
@@ -4306,11 +4765,16 @@ kernel void kernel_cpy(
|
|
|
4306
4765
|
device const char * src0,
|
|
4307
4766
|
device char * dst,
|
|
4308
4767
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4768
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4309
4769
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
4310
|
-
ushort3
|
|
4770
|
+
ushort3 tptg[[threads_per_threadgroup]]) {
|
|
4311
4771
|
const int i03 = tgpig[2];
|
|
4312
4772
|
const int i02 = tgpig[1];
|
|
4313
|
-
const int i01 = tgpig[0];
|
|
4773
|
+
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
|
|
4774
|
+
|
|
4775
|
+
if (i01 >= args.ne01) {
|
|
4776
|
+
return;
|
|
4777
|
+
}
|
|
4314
4778
|
|
|
4315
4779
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
4316
4780
|
|
|
@@ -4321,7 +4785,7 @@ kernel void kernel_cpy(
|
|
|
4321
4785
|
|
|
4322
4786
|
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
4323
4787
|
|
|
4324
|
-
for (int64_t i00 =
|
|
4788
|
+
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
4325
4789
|
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4326
4790
|
dst_data[i00] = (T1) src[0];
|
|
4327
4791
|
}
|
|
@@ -4341,6 +4805,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
|
|
|
4341
4805
|
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
|
4342
4806
|
#endif
|
|
4343
4807
|
|
|
4808
|
+
// TODO: templetify these kernels
|
|
4344
4809
|
kernel void kernel_cpy_f32_q8_0(
|
|
4345
4810
|
constant ggml_metal_kargs_cpy & args,
|
|
4346
4811
|
device const char * src0,
|
|
@@ -4364,23 +4829,7 @@ kernel void kernel_cpy_f32_q8_0(
|
|
|
4364
4829
|
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
|
4365
4830
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4366
4831
|
|
|
4367
|
-
|
|
4368
|
-
|
|
4369
|
-
for (int j = 0; j < QK8_0; j++) {
|
|
4370
|
-
const float v = src[j];
|
|
4371
|
-
amax = MAX(amax, fabs(v));
|
|
4372
|
-
}
|
|
4373
|
-
|
|
4374
|
-
const float d = amax / ((1 << 7) - 1);
|
|
4375
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4376
|
-
|
|
4377
|
-
dst_data[i00/QK8_0].d = d;
|
|
4378
|
-
|
|
4379
|
-
for (int j = 0; j < QK8_0; ++j) {
|
|
4380
|
-
const float x0 = src[j]*id;
|
|
4381
|
-
|
|
4382
|
-
dst_data[i00/QK8_0].qs[j] = round(x0);
|
|
4383
|
-
}
|
|
4832
|
+
quantize_q8_0(src, dst_data[i00/QK8_0]);
|
|
4384
4833
|
}
|
|
4385
4834
|
}
|
|
4386
4835
|
|
|
@@ -4407,32 +4856,7 @@ kernel void kernel_cpy_f32_q4_0(
|
|
|
4407
4856
|
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
|
4408
4857
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4409
4858
|
|
|
4410
|
-
|
|
4411
|
-
float max = 0.0f;
|
|
4412
|
-
|
|
4413
|
-
for (int j = 0; j < QK4_0; j++) {
|
|
4414
|
-
const float v = src[j];
|
|
4415
|
-
if (amax < fabs(v)) {
|
|
4416
|
-
amax = fabs(v);
|
|
4417
|
-
max = v;
|
|
4418
|
-
}
|
|
4419
|
-
}
|
|
4420
|
-
|
|
4421
|
-
const float d = max / -8;
|
|
4422
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4423
|
-
|
|
4424
|
-
dst_data[i00/QK4_0].d = d;
|
|
4425
|
-
|
|
4426
|
-
for (int j = 0; j < QK4_0/2; ++j) {
|
|
4427
|
-
const float x0 = src[0 + j]*id;
|
|
4428
|
-
const float x1 = src[QK4_0/2 + j]*id;
|
|
4429
|
-
|
|
4430
|
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
|
4431
|
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
|
4432
|
-
|
|
4433
|
-
dst_data[i00/QK4_0].qs[j] = xi0;
|
|
4434
|
-
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
|
4435
|
-
}
|
|
4859
|
+
quantize_q4_0(src, dst_data[i00/QK4_0]);
|
|
4436
4860
|
}
|
|
4437
4861
|
}
|
|
4438
4862
|
|
|
@@ -4459,31 +4883,7 @@ kernel void kernel_cpy_f32_q4_1(
|
|
|
4459
4883
|
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
|
4460
4884
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4461
4885
|
|
|
4462
|
-
|
|
4463
|
-
float max = -FLT_MAX;
|
|
4464
|
-
|
|
4465
|
-
for (int j = 0; j < QK4_1; j++) {
|
|
4466
|
-
const float v = src[j];
|
|
4467
|
-
if (min > v) min = v;
|
|
4468
|
-
if (max < v) max = v;
|
|
4469
|
-
}
|
|
4470
|
-
|
|
4471
|
-
const float d = (max - min) / ((1 << 4) - 1);
|
|
4472
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4473
|
-
|
|
4474
|
-
dst_data[i00/QK4_1].d = d;
|
|
4475
|
-
dst_data[i00/QK4_1].m = min;
|
|
4476
|
-
|
|
4477
|
-
for (int j = 0; j < QK4_1/2; ++j) {
|
|
4478
|
-
const float x0 = (src[0 + j] - min)*id;
|
|
4479
|
-
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
|
4480
|
-
|
|
4481
|
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
|
4482
|
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
|
4483
|
-
|
|
4484
|
-
dst_data[i00/QK4_1].qs[j] = xi0;
|
|
4485
|
-
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
|
4486
|
-
}
|
|
4886
|
+
quantize_q4_1(src, dst_data[i00/QK4_1]);
|
|
4487
4887
|
}
|
|
4488
4888
|
}
|
|
4489
4889
|
|
|
@@ -4510,38 +4910,7 @@ kernel void kernel_cpy_f32_q5_0(
|
|
|
4510
4910
|
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
|
|
4511
4911
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4512
4912
|
|
|
4513
|
-
|
|
4514
|
-
float max = 0.0f;
|
|
4515
|
-
|
|
4516
|
-
for (int j = 0; j < QK5_0; j++) {
|
|
4517
|
-
const float v = src[j];
|
|
4518
|
-
if (amax < fabs(v)) {
|
|
4519
|
-
amax = fabs(v);
|
|
4520
|
-
max = v;
|
|
4521
|
-
}
|
|
4522
|
-
}
|
|
4523
|
-
|
|
4524
|
-
const float d = max / -16;
|
|
4525
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4526
|
-
|
|
4527
|
-
dst_data[i00/QK5_0].d = d;
|
|
4528
|
-
|
|
4529
|
-
uint32_t qh = 0;
|
|
4530
|
-
for (int j = 0; j < QK5_0/2; ++j) {
|
|
4531
|
-
const float x0 = src[0 + j]*id;
|
|
4532
|
-
const float x1 = src[QK5_0/2 + j]*id;
|
|
4533
|
-
|
|
4534
|
-
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
|
4535
|
-
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
|
4536
|
-
|
|
4537
|
-
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
4538
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
4539
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
|
4540
|
-
}
|
|
4541
|
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
4542
|
-
for (int j = 0; j < 4; ++j) {
|
|
4543
|
-
dst_data[i00/QK5_0].qh[j] = qh8[j];
|
|
4544
|
-
}
|
|
4913
|
+
quantize_q5_0(src, dst_data[i00/QK5_0]);
|
|
4545
4914
|
}
|
|
4546
4915
|
}
|
|
4547
4916
|
|
|
@@ -4568,51 +4937,10 @@ kernel void kernel_cpy_f32_q5_1(
|
|
|
4568
4937
|
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
|
4569
4938
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4570
4939
|
|
|
4571
|
-
|
|
4572
|
-
float min = src[0];
|
|
4573
|
-
|
|
4574
|
-
for (int j = 1; j < QK5_1; j++) {
|
|
4575
|
-
const float v = src[j];
|
|
4576
|
-
min = v < min ? v : min;
|
|
4577
|
-
max = v > max ? v : max;
|
|
4578
|
-
}
|
|
4579
|
-
|
|
4580
|
-
const float d = (max - min) / 31;
|
|
4581
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4582
|
-
|
|
4583
|
-
dst_data[i00/QK5_1].d = d;
|
|
4584
|
-
dst_data[i00/QK5_1].m = min;
|
|
4585
|
-
|
|
4586
|
-
uint32_t qh = 0;
|
|
4587
|
-
for (int j = 0; j < QK5_1/2; ++j) {
|
|
4588
|
-
const float x0 = (src[0 + j] - min)*id;
|
|
4589
|
-
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
|
4590
|
-
|
|
4591
|
-
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
|
4592
|
-
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
|
4593
|
-
|
|
4594
|
-
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
4595
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
4596
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
|
4597
|
-
}
|
|
4598
|
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
4599
|
-
for (int j = 0; j < 4; ++j) {
|
|
4600
|
-
dst_data[i00/QK5_1].qh[j] = qh8[j];
|
|
4601
|
-
}
|
|
4940
|
+
quantize_q5_1(src, dst_data[i00/QK5_1]);
|
|
4602
4941
|
}
|
|
4603
4942
|
}
|
|
4604
4943
|
|
|
4605
|
-
static inline int best_index_int8(int n, constant float * val, float x) {
|
|
4606
|
-
if (x <= val[0]) return 0;
|
|
4607
|
-
if (x >= val[n-1]) return n-1;
|
|
4608
|
-
int ml = 0, mu = n-1;
|
|
4609
|
-
while (mu-ml > 1) {
|
|
4610
|
-
int mav = (ml+mu)/2;
|
|
4611
|
-
if (x < val[mav]) mu = mav; else ml = mav;
|
|
4612
|
-
}
|
|
4613
|
-
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
|
4614
|
-
}
|
|
4615
|
-
|
|
4616
4944
|
kernel void kernel_cpy_f32_iq4_nl(
|
|
4617
4945
|
constant ggml_metal_kargs_cpy & args,
|
|
4618
4946
|
device const char * src0,
|
|
@@ -4636,40 +4964,7 @@ kernel void kernel_cpy_f32_iq4_nl(
|
|
|
4636
4964
|
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
|
4637
4965
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4638
4966
|
|
|
4639
|
-
|
|
4640
|
-
float max = 0.0f;
|
|
4641
|
-
|
|
4642
|
-
for (int j = 0; j < QK4_NL; j++) {
|
|
4643
|
-
const float v = src[j];
|
|
4644
|
-
if (amax < fabs(v)) {
|
|
4645
|
-
amax = fabs(v);
|
|
4646
|
-
max = v;
|
|
4647
|
-
}
|
|
4648
|
-
}
|
|
4649
|
-
|
|
4650
|
-
const float d = max / kvalues_iq4nl_f[0];
|
|
4651
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4652
|
-
|
|
4653
|
-
float sumqx = 0, sumq2 = 0;
|
|
4654
|
-
for (int j = 0; j < QK4_NL/2; ++j) {
|
|
4655
|
-
const float x0 = src[0 + j]*id;
|
|
4656
|
-
const float x1 = src[QK4_NL/2 + j]*id;
|
|
4657
|
-
|
|
4658
|
-
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
|
4659
|
-
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
|
4660
|
-
|
|
4661
|
-
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
|
|
4662
|
-
|
|
4663
|
-
const float v0 = kvalues_iq4nl_f[xi0];
|
|
4664
|
-
const float v1 = kvalues_iq4nl_f[xi1];
|
|
4665
|
-
const float w0 = src[0 + j]*src[0 + j];
|
|
4666
|
-
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
|
4667
|
-
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
|
4668
|
-
sumq2 += w0*v0*v0 + w1*v1*v1;
|
|
4669
|
-
|
|
4670
|
-
}
|
|
4671
|
-
|
|
4672
|
-
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
|
|
4967
|
+
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
|
|
4673
4968
|
}
|
|
4674
4969
|
}
|
|
4675
4970
|
|
|
@@ -6350,10 +6645,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
|
6350
6645
|
|
|
6351
6646
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
6352
6647
|
kernel void kernel_get_rows_q(
|
|
6648
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6353
6649
|
device const void * src0,
|
|
6354
6650
|
device const void * src1,
|
|
6355
6651
|
device float * dst,
|
|
6356
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6357
6652
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6358
6653
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6359
6654
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6373,10 +6668,10 @@ kernel void kernel_get_rows_q(
|
|
|
6373
6668
|
|
|
6374
6669
|
template<typename T>
|
|
6375
6670
|
kernel void kernel_get_rows_f(
|
|
6671
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6376
6672
|
device const void * src0,
|
|
6377
6673
|
device const void * src1,
|
|
6378
6674
|
device float * dst,
|
|
6379
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6380
6675
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6381
6676
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6382
6677
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6394,10 +6689,10 @@ kernel void kernel_get_rows_f(
|
|
|
6394
6689
|
}
|
|
6395
6690
|
|
|
6396
6691
|
kernel void kernel_get_rows_i32(
|
|
6692
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6397
6693
|
device const void * src0,
|
|
6398
6694
|
device const void * src1,
|
|
6399
6695
|
device int32_t * dst,
|
|
6400
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6401
6696
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6402
6697
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6403
6698
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6414,6 +6709,67 @@ kernel void kernel_get_rows_i32(
|
|
|
6414
6709
|
}
|
|
6415
6710
|
}
|
|
6416
6711
|
|
|
6712
|
+
template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
|
6713
|
+
kernel void kernel_set_rows_q32(
|
|
6714
|
+
constant ggml_metal_kargs_set_rows & args,
|
|
6715
|
+
device const void * src0,
|
|
6716
|
+
device const void * src1,
|
|
6717
|
+
device float * dst,
|
|
6718
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6719
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6720
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
6721
|
+
const int32_t i03 = tgpig.z;
|
|
6722
|
+
const int32_t i02 = tgpig.y;
|
|
6723
|
+
|
|
6724
|
+
const int32_t i12 = i03%args.ne12;
|
|
6725
|
+
const int32_t i11 = i02%args.ne11;
|
|
6726
|
+
|
|
6727
|
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
|
6728
|
+
if (i01 >= args.ne01) {
|
|
6729
|
+
return;
|
|
6730
|
+
}
|
|
6731
|
+
|
|
6732
|
+
const int32_t i10 = i01;
|
|
6733
|
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
6734
|
+
|
|
6735
|
+
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
6736
|
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
6737
|
+
|
|
6738
|
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
|
6739
|
+
quantize_func(src_row + 32*ind, dst_row[ind]);
|
|
6740
|
+
}
|
|
6741
|
+
}
|
|
6742
|
+
|
|
6743
|
+
template<typename T>
|
|
6744
|
+
kernel void kernel_set_rows_f(
|
|
6745
|
+
constant ggml_metal_kargs_set_rows & args,
|
|
6746
|
+
device const void * src0,
|
|
6747
|
+
device const void * src1,
|
|
6748
|
+
device float * dst,
|
|
6749
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6750
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6751
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
6752
|
+
const int32_t i03 = tgpig.z;
|
|
6753
|
+
const int32_t i02 = tgpig.y;
|
|
6754
|
+
|
|
6755
|
+
const int32_t i12 = i03%args.ne12;
|
|
6756
|
+
const int32_t i11 = i02%args.ne11;
|
|
6757
|
+
|
|
6758
|
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
|
6759
|
+
if (i01 >= args.ne01) {
|
|
6760
|
+
return;
|
|
6761
|
+
}
|
|
6762
|
+
|
|
6763
|
+
const int32_t i10 = i01;
|
|
6764
|
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
6765
|
+
|
|
6766
|
+
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
6767
|
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
6768
|
+
|
|
6769
|
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
|
6770
|
+
dst_row[ind] = (T) src_row[ind];
|
|
6771
|
+
}
|
|
6772
|
+
}
|
|
6417
6773
|
|
|
6418
6774
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
6419
6775
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
@@ -6837,6 +7193,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
|
|
|
6837
7193
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
6838
7194
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
6839
7195
|
|
|
7196
|
+
//
|
|
7197
|
+
// set rows
|
|
7198
|
+
//
|
|
7199
|
+
|
|
7200
|
+
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
|
|
7201
|
+
|
|
7202
|
+
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
|
|
7203
|
+
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
|
|
7204
|
+
#if defined(GGML_METAL_USE_BF16)
|
|
7205
|
+
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
|
|
7206
|
+
#endif
|
|
7207
|
+
|
|
7208
|
+
typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
|
7209
|
+
|
|
7210
|
+
template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
|
|
7211
|
+
template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
|
|
7212
|
+
template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
|
|
7213
|
+
template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
|
|
7214
|
+
template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
|
|
7215
|
+
template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
|
|
7216
|
+
|
|
6840
7217
|
//
|
|
6841
7218
|
// matrix-matrix multiplication
|
|
6842
7219
|
//
|