@novastera-oss/llamarn 0.2.6 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +141 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +58 -24
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +37 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +53 -40
- package/cpp/llama.cpp/common/common.h +6 -2
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +3 -46
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +215 -76
- package/cpp/llama.cpp/ggml/CMakeLists.txt +48 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml-cpu.h +2 -0
- package/cpp/llama.cpp/ggml/include/ggml.h +33 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +64 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +124 -26
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +4 -3
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +93 -104
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +194 -69
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1158 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1571 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +213 -37
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.h +45 -45
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +59 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +90 -39
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cu +19 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mean.cuh +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cu +257 -87
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmv.cuh +2 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cu +5 -18
- package/cpp/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +61 -183
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +260 -49
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +497 -282
- package/cpp/llama.cpp/ggml/src/ggml-musa/mudnn.cuh +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1078 -468
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +5 -6
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +20 -48
- package/cpp/llama.cpp/ggml/src/ggml-sycl/concat.cpp +28 -41
- package/cpp/llama.cpp/ggml/src/ggml-sycl/conv.cpp +4 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +117 -165
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +192 -53
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +49 -67
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +99 -159
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +8 -105
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +209 -92
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gla.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +60 -80
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +158 -203
- package/cpp/llama.cpp/ggml/src/ggml-sycl/norm.cpp +55 -74
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +24 -20
- package/cpp/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +12 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +36 -28
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +487 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +69 -19
- package/cpp/llama.cpp/ggml/src/gguf.cpp +5 -1
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +133 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +25 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +78 -3
- package/cpp/llama.cpp/gguf-py/gguf/vocab.py +97 -4
- package/cpp/llama.cpp/gguf-py/pyproject.toml +2 -2
- package/cpp/llama.cpp/include/llama.h +141 -38
- package/cpp/llama.cpp/models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja +124 -0
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +150 -3
- package/cpp/llama.cpp/src/llama-arch.h +25 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +736 -274
- package/cpp/llama.cpp/src/llama-batch.h +110 -57
- package/cpp/llama.cpp/src/llama-chat.cpp +30 -8
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +360 -266
- package/cpp/llama.cpp/src/llama-context.h +27 -23
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +411 -344
- package/cpp/llama.cpp/src/llama-graph.h +126 -58
- package/cpp/llama.cpp/src/llama-hparams.cpp +10 -2
- package/cpp/llama.cpp/src/llama-hparams.h +16 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +103 -73
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +34 -42
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +345 -221
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +75 -50
- package/cpp/llama.cpp/src/llama-kv-cells.h +51 -22
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +246 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +138 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +302 -317
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +60 -68
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +73 -36
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model-saver.cpp +1 -0
- package/cpp/llama.cpp/src/llama-model.cpp +1630 -511
- package/cpp/llama.cpp/src/llama-model.h +26 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +89 -6
- package/cpp/llama.cpp/src/llama-vocab.cpp +58 -26
- package/cpp/llama.cpp/src/llama-vocab.h +1 -0
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +3 -0
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +6 -2
- package/ios/include/llama.h +141 -38
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3766 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4890 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4861 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3764 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4926 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4897 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3794 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml-cpu.h +2 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +33 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +141 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
|
|
|
35
35
|
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
|
36
36
|
};
|
|
37
37
|
|
|
38
|
+
static inline int best_index_int8(int n, constant float * val, float x) {
|
|
39
|
+
if (x <= val[0]) return 0;
|
|
40
|
+
if (x >= val[n-1]) return n-1;
|
|
41
|
+
int ml = 0, mu = n-1;
|
|
42
|
+
while (mu-ml > 1) {
|
|
43
|
+
int mav = (ml+mu)/2;
|
|
44
|
+
if (x < val[mav]) mu = mav; else ml = mav;
|
|
45
|
+
}
|
|
46
|
+
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
|
47
|
+
}
|
|
48
|
+
|
|
38
49
|
// NOTE: this is not dequantizing - we are simply fitting the template
|
|
39
50
|
template <typename type4x4>
|
|
40
51
|
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
|
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
97
108
|
}
|
|
98
109
|
}
|
|
99
110
|
|
|
111
|
+
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
112
|
+
float amax = 0.0f; // absolute max
|
|
113
|
+
float max = 0.0f;
|
|
114
|
+
|
|
115
|
+
for (int j = 0; j < QK4_0; j++) {
|
|
116
|
+
const float v = src[j];
|
|
117
|
+
if (amax < fabs(v)) {
|
|
118
|
+
amax = fabs(v);
|
|
119
|
+
max = v;
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
const float d = max / -8;
|
|
124
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
125
|
+
|
|
126
|
+
dst.d = d;
|
|
127
|
+
|
|
128
|
+
for (int j = 0; j < QK4_0/2; ++j) {
|
|
129
|
+
const float x0 = src[0 + j]*id;
|
|
130
|
+
const float x1 = src[QK4_0/2 + j]*id;
|
|
131
|
+
|
|
132
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
|
133
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
|
134
|
+
|
|
135
|
+
dst.qs[j] = xi0;
|
|
136
|
+
dst.qs[j] |= xi1 << 4;
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
|
141
|
+
float min = FLT_MAX;
|
|
142
|
+
float max = -FLT_MAX;
|
|
143
|
+
|
|
144
|
+
for (int j = 0; j < QK4_1; j++) {
|
|
145
|
+
const float v = src[j];
|
|
146
|
+
if (min > v) min = v;
|
|
147
|
+
if (max < v) max = v;
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
const float d = (max - min) / ((1 << 4) - 1);
|
|
151
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
152
|
+
|
|
153
|
+
dst.d = d;
|
|
154
|
+
dst.m = min;
|
|
155
|
+
|
|
156
|
+
for (int j = 0; j < QK4_1/2; ++j) {
|
|
157
|
+
const float x0 = (src[0 + j] - min)*id;
|
|
158
|
+
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
|
159
|
+
|
|
160
|
+
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
|
161
|
+
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
|
162
|
+
|
|
163
|
+
dst.qs[j] = xi0;
|
|
164
|
+
dst.qs[j] |= xi1 << 4;
|
|
165
|
+
}
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
|
169
|
+
float amax = 0.0f; // absolute max
|
|
170
|
+
float max = 0.0f;
|
|
171
|
+
|
|
172
|
+
for (int j = 0; j < QK5_0; j++) {
|
|
173
|
+
const float v = src[j];
|
|
174
|
+
if (amax < fabs(v)) {
|
|
175
|
+
amax = fabs(v);
|
|
176
|
+
max = v;
|
|
177
|
+
}
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
const float d = max / -16;
|
|
181
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
182
|
+
|
|
183
|
+
dst.d = d;
|
|
184
|
+
|
|
185
|
+
uint32_t qh = 0;
|
|
186
|
+
for (int j = 0; j < QK5_0/2; ++j) {
|
|
187
|
+
const float x0 = src[0 + j]*id;
|
|
188
|
+
const float x1 = src[QK5_0/2 + j]*id;
|
|
189
|
+
|
|
190
|
+
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
|
191
|
+
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
|
192
|
+
|
|
193
|
+
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
194
|
+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
195
|
+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
199
|
+
|
|
200
|
+
for (int j = 0; j < 4; ++j) {
|
|
201
|
+
dst.qh[j] = qh8[j];
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
|
206
|
+
float max = src[0];
|
|
207
|
+
float min = src[0];
|
|
208
|
+
|
|
209
|
+
for (int j = 1; j < QK5_1; j++) {
|
|
210
|
+
const float v = src[j];
|
|
211
|
+
min = v < min ? v : min;
|
|
212
|
+
max = v > max ? v : max;
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
const float d = (max - min) / 31;
|
|
216
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
217
|
+
|
|
218
|
+
dst.d = d;
|
|
219
|
+
dst.m = min;
|
|
220
|
+
|
|
221
|
+
uint32_t qh = 0;
|
|
222
|
+
for (int j = 0; j < QK5_1/2; ++j) {
|
|
223
|
+
const float x0 = (src[0 + j] - min)*id;
|
|
224
|
+
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
|
225
|
+
|
|
226
|
+
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
|
227
|
+
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
|
228
|
+
|
|
229
|
+
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
230
|
+
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
231
|
+
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
235
|
+
|
|
236
|
+
for (int j = 0; j < 4; ++j) {
|
|
237
|
+
dst.qh[j] = qh8[j];
|
|
238
|
+
}
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
|
|
242
|
+
float amax = 0.0f; // absolute max
|
|
243
|
+
float max = 0.0f;
|
|
244
|
+
|
|
245
|
+
for (int j = 0; j < QK4_NL; j++) {
|
|
246
|
+
const float v = src[j];
|
|
247
|
+
if (amax < fabs(v)) {
|
|
248
|
+
amax = fabs(v);
|
|
249
|
+
max = v;
|
|
250
|
+
}
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
const float d = max / kvalues_iq4nl_f[0];
|
|
254
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
255
|
+
|
|
256
|
+
float sumqx = 0, sumq2 = 0;
|
|
257
|
+
for (int j = 0; j < QK4_NL/2; ++j) {
|
|
258
|
+
const float x0 = src[0 + j]*id;
|
|
259
|
+
const float x1 = src[QK4_NL/2 + j]*id;
|
|
260
|
+
|
|
261
|
+
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
|
262
|
+
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
|
263
|
+
|
|
264
|
+
dst.qs[j] = xi0 | (xi1 << 4);
|
|
265
|
+
|
|
266
|
+
const float v0 = kvalues_iq4nl_f[xi0];
|
|
267
|
+
const float v1 = kvalues_iq4nl_f[xi1];
|
|
268
|
+
const float w0 = src[0 + j]*src[0 + j];
|
|
269
|
+
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
|
270
|
+
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
|
271
|
+
sumq2 += w0*v0*v0 + w1*v1*v1;
|
|
272
|
+
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
|
|
276
|
+
}
|
|
277
|
+
|
|
100
278
|
template <typename type4x4>
|
|
101
279
|
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
|
102
280
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
|
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
|
|
|
279
457
|
}
|
|
280
458
|
}
|
|
281
459
|
|
|
460
|
+
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
|
461
|
+
float amax = 0.0f; // absolute max
|
|
462
|
+
|
|
463
|
+
for (int j = 0; j < QK8_0; j++) {
|
|
464
|
+
const float v = src[j];
|
|
465
|
+
amax = MAX(amax, fabs(v));
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
const float d = amax / ((1 << 7) - 1);
|
|
469
|
+
const float id = d ? 1.0f/d : 0.0f;
|
|
470
|
+
|
|
471
|
+
dst.d = d;
|
|
472
|
+
|
|
473
|
+
for (int j = 0; j < QK8_0; ++j) {
|
|
474
|
+
const float x0 = src[j]*id;
|
|
475
|
+
|
|
476
|
+
dst.qs[j] = round(x0);
|
|
477
|
+
}
|
|
478
|
+
}
|
|
479
|
+
|
|
282
480
|
template <typename type4x4>
|
|
283
481
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
|
284
482
|
const float d = xb->d;
|
|
@@ -993,31 +1191,61 @@ kernel void kernel_neg(
|
|
|
993
1191
|
dst[tpig] = -src0[tpig];
|
|
994
1192
|
}
|
|
995
1193
|
|
|
1194
|
+
template <bool norm>
|
|
996
1195
|
kernel void kernel_sum_rows(
|
|
1196
|
+
constant ggml_metal_kargs_sum_rows & args,
|
|
997
1197
|
device const float * src0,
|
|
998
1198
|
device float * dst,
|
|
999
|
-
|
|
1000
|
-
uint3
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1199
|
+
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
|
1200
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1201
|
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
1202
|
+
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
|
1203
|
+
ushort tiisg[[thread_index_in_simdgroup]],
|
|
1204
|
+
ushort3 ntg[[threads_per_threadgroup]]) {
|
|
1205
|
+
int64_t i3 = tgpig.z;
|
|
1206
|
+
int64_t i2 = tgpig.y;
|
|
1207
|
+
int64_t i1 = tgpig.x;
|
|
1004
1208
|
|
|
1005
1209
|
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
|
1006
1210
|
return;
|
|
1007
1211
|
}
|
|
1008
1212
|
|
|
1213
|
+
if (sgitg == 0) {
|
|
1214
|
+
shmem_f32[tiisg] = 0.0f;
|
|
1215
|
+
}
|
|
1216
|
+
|
|
1009
1217
|
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
|
1010
1218
|
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
|
1011
1219
|
|
|
1012
|
-
float
|
|
1220
|
+
float sumf = 0;
|
|
1221
|
+
|
|
1222
|
+
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
|
1223
|
+
sumf += src_row[i0];
|
|
1224
|
+
}
|
|
1225
|
+
|
|
1226
|
+
sumf = simd_sum(sumf);
|
|
1013
1227
|
|
|
1014
|
-
|
|
1015
|
-
|
|
1228
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1229
|
+
|
|
1230
|
+
if (tiisg == 0) {
|
|
1231
|
+
shmem_f32[sgitg] = sumf;
|
|
1016
1232
|
}
|
|
1017
1233
|
|
|
1018
|
-
|
|
1234
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1235
|
+
|
|
1236
|
+
sumf = shmem_f32[tiisg];
|
|
1237
|
+
sumf = simd_sum(sumf);
|
|
1238
|
+
|
|
1239
|
+
if (tpitg.x == 0) {
|
|
1240
|
+
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
|
1241
|
+
}
|
|
1019
1242
|
}
|
|
1020
1243
|
|
|
1244
|
+
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
|
1245
|
+
|
|
1246
|
+
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
|
1247
|
+
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
|
1248
|
+
|
|
1021
1249
|
template<typename T>
|
|
1022
1250
|
kernel void kernel_soft_max(
|
|
1023
1251
|
device const char * src0,
|
|
@@ -2502,6 +2730,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
|
|
|
2502
2730
|
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
|
|
2503
2731
|
#endif
|
|
2504
2732
|
|
|
2733
|
+
template<typename T04, typename T14, typename args_t>
|
|
2734
|
+
void kernel_mul_mv_c4_impl(
|
|
2735
|
+
args_t args,
|
|
2736
|
+
device const char * src0,
|
|
2737
|
+
device const char * src1,
|
|
2738
|
+
device char * dst,
|
|
2739
|
+
uint3 tgpig,
|
|
2740
|
+
ushort tiisg) {
|
|
2741
|
+
const int r0 = tgpig.x*32 + tiisg;
|
|
2742
|
+
const int rb = tgpig.y*N_MV_T_T;
|
|
2743
|
+
const int im = tgpig.z;
|
|
2744
|
+
|
|
2745
|
+
if (r0 >= args.ne01) {
|
|
2746
|
+
return;
|
|
2747
|
+
}
|
|
2748
|
+
|
|
2749
|
+
const uint i12 = im%args.ne12;
|
|
2750
|
+
const uint i13 = im/args.ne12;
|
|
2751
|
+
|
|
2752
|
+
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
|
2753
|
+
|
|
2754
|
+
device const T04 * x = (device const T04 *) (src0 + offset0);
|
|
2755
|
+
|
|
2756
|
+
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
|
2757
|
+
|
|
2758
|
+
for (int row = 0; row < N_MV_T_T; ++row) {
|
|
2759
|
+
int r1 = rb + row;
|
|
2760
|
+
if (r1 >= args.ne11) {
|
|
2761
|
+
break;
|
|
2762
|
+
}
|
|
2763
|
+
|
|
2764
|
+
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
|
2765
|
+
|
|
2766
|
+
device const T14 * y = (device const T14 *) (src1 + offset1);
|
|
2767
|
+
|
|
2768
|
+
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
|
|
2769
|
+
}
|
|
2770
|
+
}
|
|
2771
|
+
|
|
2772
|
+
template<typename T04, typename T14>
|
|
2773
|
+
kernel void kernel_mul_mv_c4(
|
|
2774
|
+
constant ggml_metal_kargs_mul_mv & args,
|
|
2775
|
+
device const char * src0,
|
|
2776
|
+
device const char * src1,
|
|
2777
|
+
device char * dst,
|
|
2778
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2779
|
+
ushort tiisg[[thread_index_in_simdgroup]]) {
|
|
2780
|
+
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
|
|
2781
|
+
args,
|
|
2782
|
+
src0,
|
|
2783
|
+
src1,
|
|
2784
|
+
dst,
|
|
2785
|
+
tgpig,
|
|
2786
|
+
tiisg);
|
|
2787
|
+
}
|
|
2788
|
+
|
|
2789
|
+
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
|
|
2790
|
+
|
|
2791
|
+
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
|
|
2792
|
+
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
|
|
2793
|
+
#if defined(GGML_METAL_USE_BF16)
|
|
2794
|
+
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
|
|
2795
|
+
#endif
|
|
2796
|
+
|
|
2505
2797
|
template<typename T, typename T4>
|
|
2506
2798
|
kernel void kernel_mul_mv_1row(
|
|
2507
2799
|
constant ggml_metal_kargs_mul_mv & args,
|
|
@@ -3328,14 +3620,12 @@ kernel void kernel_flash_attn_ext(
|
|
|
3328
3620
|
constexpr short NW = N_SIMDWIDTH;
|
|
3329
3621
|
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
|
3330
3622
|
|
|
3331
|
-
const short TS = nsg*SH;
|
|
3332
|
-
const short T = DK + 2*TS; // shared memory size per query in (half)
|
|
3623
|
+
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
|
3624
|
+
const short T = 2*DK + 2*TS; // shared memory size per query in (half)
|
|
3333
3625
|
|
|
3334
|
-
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 +
|
|
3335
|
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +
|
|
3336
|
-
threadgroup
|
|
3337
|
-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
|
|
3338
|
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
|
|
3626
|
+
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
|
3627
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
|
3628
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
|
|
3339
3629
|
|
|
3340
3630
|
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
|
3341
3631
|
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
|
@@ -3354,7 +3644,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
3354
3644
|
if (iq1 + j < args.ne01) {
|
|
3355
3645
|
sq4[j*DK4 + i] = (q4_t) q4[i];
|
|
3356
3646
|
} else {
|
|
3357
|
-
sq4[j*DK4 + i] =
|
|
3647
|
+
sq4[j*DK4 + i] = 0;
|
|
3358
3648
|
}
|
|
3359
3649
|
}
|
|
3360
3650
|
}
|
|
@@ -3548,20 +3838,20 @@ kernel void kernel_flash_attn_ext(
|
|
|
3548
3838
|
|
|
3549
3839
|
// O = diag(ms)*O
|
|
3550
3840
|
{
|
|
3551
|
-
s8x8_t
|
|
3552
|
-
simdgroup_load(
|
|
3841
|
+
s8x8_t ms;
|
|
3842
|
+
simdgroup_load(ms, ss + 2*C, TS, 0, false);
|
|
3553
3843
|
|
|
3554
3844
|
#pragma unroll(DV8)
|
|
3555
3845
|
for (short i = 0; i < DV8; ++i) {
|
|
3556
|
-
simdgroup_multiply(lo[i],
|
|
3846
|
+
simdgroup_multiply(lo[i], ms, lo[i]);
|
|
3557
3847
|
}
|
|
3558
3848
|
}
|
|
3559
3849
|
|
|
3560
3850
|
// O = O + (Q*K^T)*V
|
|
3561
3851
|
{
|
|
3562
3852
|
for (short cc = 0; cc < C/8; ++cc) {
|
|
3563
|
-
s8x8_t
|
|
3564
|
-
simdgroup_load(
|
|
3853
|
+
s8x8_t vs;
|
|
3854
|
+
simdgroup_load(vs, ss + 8*cc, TS, 0, false);
|
|
3565
3855
|
|
|
3566
3856
|
if (is_same<vd4x4_t, v4x4_t>::value) {
|
|
3567
3857
|
// we can read directly from global memory
|
|
@@ -3572,7 +3862,7 @@ kernel void kernel_flash_attn_ext(
|
|
|
3572
3862
|
v8x8_t mv;
|
|
3573
3863
|
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
|
3574
3864
|
|
|
3575
|
-
simdgroup_multiply_accumulate(lo[i],
|
|
3865
|
+
simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
|
|
3576
3866
|
}
|
|
3577
3867
|
} else {
|
|
3578
3868
|
for (short ii = 0; ii < DV16; ii += 4) {
|
|
@@ -3593,10 +3883,10 @@ kernel void kernel_flash_attn_ext(
|
|
|
3593
3883
|
v8x8_t mv;
|
|
3594
3884
|
|
|
3595
3885
|
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
|
3596
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0],
|
|
3886
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
|
|
3597
3887
|
|
|
3598
3888
|
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
|
3599
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1],
|
|
3889
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
|
|
3600
3890
|
}
|
|
3601
3891
|
} else {
|
|
3602
3892
|
if (ii + tx < DV16) {
|
|
@@ -3611,10 +3901,10 @@ kernel void kernel_flash_attn_ext(
|
|
|
3611
3901
|
v8x8_t mv;
|
|
3612
3902
|
|
|
3613
3903
|
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
|
3614
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0],
|
|
3904
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
|
|
3615
3905
|
|
|
3616
3906
|
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
|
3617
|
-
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1],
|
|
3907
|
+
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
|
|
3618
3908
|
}
|
|
3619
3909
|
}
|
|
3620
3910
|
}
|
|
@@ -3624,93 +3914,89 @@ kernel void kernel_flash_attn_ext(
|
|
|
3624
3914
|
}
|
|
3625
3915
|
|
|
3626
3916
|
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
|
3627
|
-
for (short j =
|
|
3628
|
-
|
|
3629
|
-
|
|
3630
|
-
ss[j*TS + 1] = M[j];
|
|
3631
|
-
}
|
|
3917
|
+
for (short j = tiisg; j < Q; j += NW) {
|
|
3918
|
+
ss[j*TS + 0] = S[j];
|
|
3919
|
+
ss[j*TS + 1] = M[j];
|
|
3632
3920
|
}
|
|
3633
3921
|
}
|
|
3634
3922
|
|
|
3635
|
-
|
|
3636
|
-
for (ushort sg = 1; sg < nsg; ++sg) {
|
|
3637
|
-
float S = { 0.0f };
|
|
3638
|
-
float M = { -__FLT_MAX__/2 };
|
|
3923
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3639
3924
|
|
|
3640
|
-
|
|
3925
|
+
threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
|
|
3926
|
+
threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
|
|
3641
3927
|
|
|
3642
|
-
|
|
3643
|
-
|
|
3644
|
-
|
|
3645
|
-
|
|
3646
|
-
|
|
3928
|
+
// store result to shared memory in F32
|
|
3929
|
+
if (sgitg == 0) {
|
|
3930
|
+
for (short i = 0; i < DV8; ++i) {
|
|
3931
|
+
//simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
|
3932
|
+
simdgroup_float8x8 t(1.0f);
|
|
3933
|
+
simdgroup_multiply(t, lo[i], t);
|
|
3934
|
+
simdgroup_store(t, so + i*8, DV, 0, false);
|
|
3647
3935
|
}
|
|
3936
|
+
}
|
|
3648
3937
|
|
|
3649
|
-
|
|
3938
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3650
3939
|
|
|
3651
|
-
|
|
3652
|
-
|
|
3653
|
-
|
|
3654
|
-
|
|
3655
|
-
const float
|
|
3940
|
+
// reduce the warps sequentially
|
|
3941
|
+
for (ushort sg = 1; sg < nsg; ++sg) {
|
|
3942
|
+
if (sgitg == sg) {
|
|
3943
|
+
for (short j = tiisg; j < Q; j += NW) {
|
|
3944
|
+
const float S0 = ss[j*TS - 1*SH + 0];
|
|
3945
|
+
const float S1 = ss[j*TS + 0];
|
|
3656
3946
|
|
|
3657
|
-
const float M0 = ss[j*TS +
|
|
3658
|
-
const float M1 = ss[j*TS
|
|
3947
|
+
const float M0 = ss[j*TS - 1*SH + 1];
|
|
3948
|
+
const float M1 = ss[j*TS + 1];
|
|
3659
3949
|
|
|
3660
|
-
M = max(M0, M1);
|
|
3950
|
+
const float M = max(M0, M1);
|
|
3661
3951
|
|
|
3662
|
-
|
|
3663
|
-
|
|
3952
|
+
float ms0 = exp(M0 - M);
|
|
3953
|
+
float ms1 = exp(M1 - M);
|
|
3664
3954
|
|
|
3665
|
-
S = S0*ms0 + S1*ms1;
|
|
3955
|
+
const float S = S0*ms0 + S1*ms1;
|
|
3666
3956
|
|
|
3667
|
-
|
|
3668
|
-
|
|
3669
|
-
ss[j*TS + 1] = M;
|
|
3957
|
+
ss[j*TS + 0] = S;
|
|
3958
|
+
ss[j*TS + 1] = M;
|
|
3670
3959
|
|
|
3671
|
-
|
|
3672
|
-
|
|
3673
|
-
}
|
|
3960
|
+
ss[j*TS + 2*C + j - 1*SH] = ms0;
|
|
3961
|
+
ss[j*TS + 2*C + j ] = ms1;
|
|
3674
3962
|
}
|
|
3675
3963
|
|
|
3964
|
+
//simdgroup_barrier(mem_flags::mem_threadgroup);
|
|
3965
|
+
|
|
3676
3966
|
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
|
3677
3967
|
{
|
|
3678
3968
|
s8x8_t ms0;
|
|
3679
3969
|
s8x8_t ms1;
|
|
3680
3970
|
|
|
3681
|
-
simdgroup_load(ms0, ss + 2*C,
|
|
3682
|
-
simdgroup_load(ms1, ss + 2*C
|
|
3971
|
+
simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
|
|
3972
|
+
simdgroup_load(ms1, ss + 2*C, TS, 0, false);
|
|
3683
3973
|
|
|
3684
3974
|
#pragma unroll(DV8)
|
|
3685
3975
|
for (short i = 0; i < DV8; ++i) {
|
|
3686
|
-
|
|
3976
|
+
simdgroup_float8x8 t;
|
|
3687
3977
|
|
|
3688
3978
|
simdgroup_load (t, so + i*8, DV, 0, false);
|
|
3689
|
-
simdgroup_multiply(t,
|
|
3979
|
+
simdgroup_multiply(t, ms0, t);
|
|
3690
3980
|
|
|
3691
|
-
simdgroup_multiply_accumulate(
|
|
3981
|
+
simdgroup_multiply_accumulate(t, ms1, lo[i], t);
|
|
3982
|
+
simdgroup_store(t, so + i*8, DV, 0, false);
|
|
3692
3983
|
}
|
|
3693
3984
|
}
|
|
3694
3985
|
}
|
|
3695
|
-
}
|
|
3696
3986
|
|
|
3697
|
-
|
|
3698
|
-
if (sgitg == 0) {
|
|
3699
|
-
for (short i = 0; i < DV8; ++i) {
|
|
3700
|
-
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
|
3701
|
-
}
|
|
3987
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
3702
3988
|
}
|
|
3703
3989
|
|
|
3704
|
-
|
|
3990
|
+
threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
|
|
3705
3991
|
|
|
3706
3992
|
// final rescale with 1/S and store to global memory
|
|
3707
|
-
|
|
3708
|
-
|
|
3709
|
-
const float S = ss[j*TS + 0];
|
|
3993
|
+
for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
|
|
3994
|
+
const float S = 1.0f/sf[j*TS + 0];
|
|
3710
3995
|
|
|
3711
|
-
|
|
3712
|
-
|
|
3713
|
-
|
|
3996
|
+
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
|
|
3997
|
+
|
|
3998
|
+
for (short i = tiisg; i < DV4; i += NW) {
|
|
3999
|
+
dst4[i] = (float4) so4[j*DV4 + i]*S;
|
|
3714
4000
|
}
|
|
3715
4001
|
}
|
|
3716
4002
|
}
|
|
@@ -3719,12 +4005,22 @@ kernel void kernel_flash_attn_ext(
|
|
|
3719
4005
|
// template to be able to explore different combinations
|
|
3720
4006
|
//
|
|
3721
4007
|
#define FA_TYPES \
|
|
3722
|
-
|
|
3723
|
-
half,
|
|
3724
|
-
half,
|
|
3725
|
-
float,
|
|
3726
|
-
float,
|
|
3727
|
-
half,
|
|
4008
|
+
float, float4, simdgroup_float8x8, \
|
|
4009
|
+
half, half4x4, simdgroup_half8x8, \
|
|
4010
|
+
half, half4x4, simdgroup_half8x8, \
|
|
4011
|
+
float, simdgroup_float8x8, \
|
|
4012
|
+
float, simdgroup_float8x8, \
|
|
4013
|
+
half, half4, simdgroup_half8x8
|
|
4014
|
+
//float, float4, simdgroup_float8x8
|
|
4015
|
+
|
|
4016
|
+
#define FA_TYPES_BF \
|
|
4017
|
+
bfloat, bfloat4, simdgroup_bfloat8x8, \
|
|
4018
|
+
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
|
|
4019
|
+
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
|
|
4020
|
+
float, simdgroup_float8x8, \
|
|
4021
|
+
float, simdgroup_float8x8, \
|
|
4022
|
+
half, half4, simdgroup_half8x8
|
|
4023
|
+
//float, float4, simdgroup_float8x8
|
|
3728
4024
|
|
|
3729
4025
|
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
|
3730
4026
|
|
|
@@ -3739,15 +4035,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
|
|
|
3739
4035
|
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
|
|
3740
4036
|
|
|
3741
4037
|
#if defined(GGML_METAL_USE_BF16)
|
|
3742
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3743
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3744
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3745
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3746
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3747
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3748
|
-
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3749
|
-
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
3750
|
-
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<
|
|
4038
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
|
4039
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
|
4040
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
|
4041
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
|
4042
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
|
|
4043
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
|
4044
|
+
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
|
4045
|
+
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
|
4046
|
+
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
|
|
3751
4047
|
#endif
|
|
3752
4048
|
|
|
3753
4049
|
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
|
@@ -3801,6 +4097,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
|
|
|
3801
4097
|
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
|
|
3802
4098
|
|
|
3803
4099
|
#undef FA_TYPES
|
|
4100
|
+
#undef FA_TYPES_BF
|
|
3804
4101
|
|
|
3805
4102
|
template<
|
|
3806
4103
|
typename q4_t, // query types in shared memory
|
|
@@ -3847,12 +4144,12 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
3847
4144
|
|
|
3848
4145
|
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
|
3849
4146
|
|
|
3850
|
-
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 +
|
|
3851
|
-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +
|
|
3852
|
-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 +
|
|
3853
|
-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 +
|
|
3854
|
-
threadgroup float * sm = (threadgroup float *) (shmem_f16 +
|
|
3855
|
-
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
|
|
4147
|
+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
|
4148
|
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
|
4149
|
+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
|
|
4150
|
+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
|
|
4151
|
+
threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
|
|
4152
|
+
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
|
|
3856
4153
|
|
|
3857
4154
|
// store the result for all queries in local memory (the O matrix from the paper)
|
|
3858
4155
|
o4_t lo[DV4/NL];
|
|
@@ -4157,7 +4454,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
|
4157
4454
|
half4, \
|
|
4158
4455
|
float, \
|
|
4159
4456
|
float, float4, \
|
|
4160
|
-
|
|
4457
|
+
float4
|
|
4161
4458
|
|
|
4162
4459
|
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
|
4163
4460
|
|
|
@@ -4271,11 +4568,16 @@ kernel void kernel_cpy(
|
|
|
4271
4568
|
device const char * src0,
|
|
4272
4569
|
device char * dst,
|
|
4273
4570
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
4571
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
4274
4572
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
4275
|
-
ushort3
|
|
4573
|
+
ushort3 tptg[[threads_per_threadgroup]]) {
|
|
4276
4574
|
const int i03 = tgpig[2];
|
|
4277
4575
|
const int i02 = tgpig[1];
|
|
4278
|
-
const int i01 = tgpig[0];
|
|
4576
|
+
const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
|
|
4577
|
+
|
|
4578
|
+
if (i01 >= args.ne01) {
|
|
4579
|
+
return;
|
|
4580
|
+
}
|
|
4279
4581
|
|
|
4280
4582
|
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
|
4281
4583
|
|
|
@@ -4286,7 +4588,7 @@ kernel void kernel_cpy(
|
|
|
4286
4588
|
|
|
4287
4589
|
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
|
4288
4590
|
|
|
4289
|
-
for (int64_t i00 =
|
|
4591
|
+
for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
|
|
4290
4592
|
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4291
4593
|
dst_data[i00] = (T1) src[0];
|
|
4292
4594
|
}
|
|
@@ -4306,6 +4608,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bf
|
|
|
4306
4608
|
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
|
|
4307
4609
|
#endif
|
|
4308
4610
|
|
|
4611
|
+
// TODO: templetify these kernels
|
|
4309
4612
|
kernel void kernel_cpy_f32_q8_0(
|
|
4310
4613
|
constant ggml_metal_kargs_cpy & args,
|
|
4311
4614
|
device const char * src0,
|
|
@@ -4329,23 +4632,7 @@ kernel void kernel_cpy_f32_q8_0(
|
|
|
4329
4632
|
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
|
4330
4633
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4331
4634
|
|
|
4332
|
-
|
|
4333
|
-
|
|
4334
|
-
for (int j = 0; j < QK8_0; j++) {
|
|
4335
|
-
const float v = src[j];
|
|
4336
|
-
amax = MAX(amax, fabs(v));
|
|
4337
|
-
}
|
|
4338
|
-
|
|
4339
|
-
const float d = amax / ((1 << 7) - 1);
|
|
4340
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4341
|
-
|
|
4342
|
-
dst_data[i00/QK8_0].d = d;
|
|
4343
|
-
|
|
4344
|
-
for (int j = 0; j < QK8_0; ++j) {
|
|
4345
|
-
const float x0 = src[j]*id;
|
|
4346
|
-
|
|
4347
|
-
dst_data[i00/QK8_0].qs[j] = round(x0);
|
|
4348
|
-
}
|
|
4635
|
+
quantize_q8_0(src, dst_data[i00/QK8_0]);
|
|
4349
4636
|
}
|
|
4350
4637
|
}
|
|
4351
4638
|
|
|
@@ -4372,32 +4659,7 @@ kernel void kernel_cpy_f32_q4_0(
|
|
|
4372
4659
|
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
|
4373
4660
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4374
4661
|
|
|
4375
|
-
|
|
4376
|
-
float max = 0.0f;
|
|
4377
|
-
|
|
4378
|
-
for (int j = 0; j < QK4_0; j++) {
|
|
4379
|
-
const float v = src[j];
|
|
4380
|
-
if (amax < fabs(v)) {
|
|
4381
|
-
amax = fabs(v);
|
|
4382
|
-
max = v;
|
|
4383
|
-
}
|
|
4384
|
-
}
|
|
4385
|
-
|
|
4386
|
-
const float d = max / -8;
|
|
4387
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4388
|
-
|
|
4389
|
-
dst_data[i00/QK4_0].d = d;
|
|
4390
|
-
|
|
4391
|
-
for (int j = 0; j < QK4_0/2; ++j) {
|
|
4392
|
-
const float x0 = src[0 + j]*id;
|
|
4393
|
-
const float x1 = src[QK4_0/2 + j]*id;
|
|
4394
|
-
|
|
4395
|
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
|
4396
|
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
|
4397
|
-
|
|
4398
|
-
dst_data[i00/QK4_0].qs[j] = xi0;
|
|
4399
|
-
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
|
|
4400
|
-
}
|
|
4662
|
+
quantize_q4_0(src, dst_data[i00/QK4_0]);
|
|
4401
4663
|
}
|
|
4402
4664
|
}
|
|
4403
4665
|
|
|
@@ -4424,31 +4686,7 @@ kernel void kernel_cpy_f32_q4_1(
|
|
|
4424
4686
|
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
|
4425
4687
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4426
4688
|
|
|
4427
|
-
|
|
4428
|
-
float max = -FLT_MAX;
|
|
4429
|
-
|
|
4430
|
-
for (int j = 0; j < QK4_1; j++) {
|
|
4431
|
-
const float v = src[j];
|
|
4432
|
-
if (min > v) min = v;
|
|
4433
|
-
if (max < v) max = v;
|
|
4434
|
-
}
|
|
4435
|
-
|
|
4436
|
-
const float d = (max - min) / ((1 << 4) - 1);
|
|
4437
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4438
|
-
|
|
4439
|
-
dst_data[i00/QK4_1].d = d;
|
|
4440
|
-
dst_data[i00/QK4_1].m = min;
|
|
4441
|
-
|
|
4442
|
-
for (int j = 0; j < QK4_1/2; ++j) {
|
|
4443
|
-
const float x0 = (src[0 + j] - min)*id;
|
|
4444
|
-
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
|
4445
|
-
|
|
4446
|
-
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
|
4447
|
-
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
|
4448
|
-
|
|
4449
|
-
dst_data[i00/QK4_1].qs[j] = xi0;
|
|
4450
|
-
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
|
|
4451
|
-
}
|
|
4689
|
+
quantize_q4_1(src, dst_data[i00/QK4_1]);
|
|
4452
4690
|
}
|
|
4453
4691
|
}
|
|
4454
4692
|
|
|
@@ -4475,38 +4713,7 @@ kernel void kernel_cpy_f32_q5_0(
|
|
|
4475
4713
|
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
|
|
4476
4714
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4477
4715
|
|
|
4478
|
-
|
|
4479
|
-
float max = 0.0f;
|
|
4480
|
-
|
|
4481
|
-
for (int j = 0; j < QK5_0; j++) {
|
|
4482
|
-
const float v = src[j];
|
|
4483
|
-
if (amax < fabs(v)) {
|
|
4484
|
-
amax = fabs(v);
|
|
4485
|
-
max = v;
|
|
4486
|
-
}
|
|
4487
|
-
}
|
|
4488
|
-
|
|
4489
|
-
const float d = max / -16;
|
|
4490
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4491
|
-
|
|
4492
|
-
dst_data[i00/QK5_0].d = d;
|
|
4493
|
-
|
|
4494
|
-
uint32_t qh = 0;
|
|
4495
|
-
for (int j = 0; j < QK5_0/2; ++j) {
|
|
4496
|
-
const float x0 = src[0 + j]*id;
|
|
4497
|
-
const float x1 = src[QK5_0/2 + j]*id;
|
|
4498
|
-
|
|
4499
|
-
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
|
4500
|
-
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
|
4501
|
-
|
|
4502
|
-
dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
4503
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
4504
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
|
4505
|
-
}
|
|
4506
|
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
4507
|
-
for (int j = 0; j < 4; ++j) {
|
|
4508
|
-
dst_data[i00/QK5_0].qh[j] = qh8[j];
|
|
4509
|
-
}
|
|
4716
|
+
quantize_q5_0(src, dst_data[i00/QK5_0]);
|
|
4510
4717
|
}
|
|
4511
4718
|
}
|
|
4512
4719
|
|
|
@@ -4533,51 +4740,10 @@ kernel void kernel_cpy_f32_q5_1(
|
|
|
4533
4740
|
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
|
4534
4741
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4535
4742
|
|
|
4536
|
-
|
|
4537
|
-
float min = src[0];
|
|
4538
|
-
|
|
4539
|
-
for (int j = 1; j < QK5_1; j++) {
|
|
4540
|
-
const float v = src[j];
|
|
4541
|
-
min = v < min ? v : min;
|
|
4542
|
-
max = v > max ? v : max;
|
|
4543
|
-
}
|
|
4544
|
-
|
|
4545
|
-
const float d = (max - min) / 31;
|
|
4546
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4547
|
-
|
|
4548
|
-
dst_data[i00/QK5_1].d = d;
|
|
4549
|
-
dst_data[i00/QK5_1].m = min;
|
|
4550
|
-
|
|
4551
|
-
uint32_t qh = 0;
|
|
4552
|
-
for (int j = 0; j < QK5_1/2; ++j) {
|
|
4553
|
-
const float x0 = (src[0 + j] - min)*id;
|
|
4554
|
-
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
|
4555
|
-
|
|
4556
|
-
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
|
4557
|
-
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
|
4558
|
-
|
|
4559
|
-
dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
|
4560
|
-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
|
4561
|
-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
|
4562
|
-
}
|
|
4563
|
-
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
|
4564
|
-
for (int j = 0; j < 4; ++j) {
|
|
4565
|
-
dst_data[i00/QK5_1].qh[j] = qh8[j];
|
|
4566
|
-
}
|
|
4743
|
+
quantize_q5_1(src, dst_data[i00/QK5_1]);
|
|
4567
4744
|
}
|
|
4568
4745
|
}
|
|
4569
4746
|
|
|
4570
|
-
static inline int best_index_int8(int n, constant float * val, float x) {
|
|
4571
|
-
if (x <= val[0]) return 0;
|
|
4572
|
-
if (x >= val[n-1]) return n-1;
|
|
4573
|
-
int ml = 0, mu = n-1;
|
|
4574
|
-
while (mu-ml > 1) {
|
|
4575
|
-
int mav = (ml+mu)/2;
|
|
4576
|
-
if (x < val[mav]) mu = mav; else ml = mav;
|
|
4577
|
-
}
|
|
4578
|
-
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
|
4579
|
-
}
|
|
4580
|
-
|
|
4581
4747
|
kernel void kernel_cpy_f32_iq4_nl(
|
|
4582
4748
|
constant ggml_metal_kargs_cpy & args,
|
|
4583
4749
|
device const char * src0,
|
|
@@ -4601,40 +4767,7 @@ kernel void kernel_cpy_f32_iq4_nl(
|
|
|
4601
4767
|
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
|
4602
4768
|
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
|
4603
4769
|
|
|
4604
|
-
|
|
4605
|
-
float max = 0.0f;
|
|
4606
|
-
|
|
4607
|
-
for (int j = 0; j < QK4_NL; j++) {
|
|
4608
|
-
const float v = src[j];
|
|
4609
|
-
if (amax < fabs(v)) {
|
|
4610
|
-
amax = fabs(v);
|
|
4611
|
-
max = v;
|
|
4612
|
-
}
|
|
4613
|
-
}
|
|
4614
|
-
|
|
4615
|
-
const float d = max / kvalues_iq4nl_f[0];
|
|
4616
|
-
const float id = d ? 1.0f/d : 0.0f;
|
|
4617
|
-
|
|
4618
|
-
float sumqx = 0, sumq2 = 0;
|
|
4619
|
-
for (int j = 0; j < QK4_NL/2; ++j) {
|
|
4620
|
-
const float x0 = src[0 + j]*id;
|
|
4621
|
-
const float x1 = src[QK4_NL/2 + j]*id;
|
|
4622
|
-
|
|
4623
|
-
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
|
4624
|
-
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
|
4625
|
-
|
|
4626
|
-
dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
|
|
4627
|
-
|
|
4628
|
-
const float v0 = kvalues_iq4nl_f[xi0];
|
|
4629
|
-
const float v1 = kvalues_iq4nl_f[xi1];
|
|
4630
|
-
const float w0 = src[0 + j]*src[0 + j];
|
|
4631
|
-
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
|
4632
|
-
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
|
4633
|
-
sumq2 += w0*v0*v0 + w1*v1*v1;
|
|
4634
|
-
|
|
4635
|
-
}
|
|
4636
|
-
|
|
4637
|
-
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
|
|
4770
|
+
quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
|
|
4638
4771
|
}
|
|
4639
4772
|
}
|
|
4640
4773
|
|
|
@@ -6315,10 +6448,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
|
6315
6448
|
|
|
6316
6449
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
6317
6450
|
kernel void kernel_get_rows_q(
|
|
6451
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6318
6452
|
device const void * src0,
|
|
6319
6453
|
device const void * src1,
|
|
6320
6454
|
device float * dst,
|
|
6321
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6322
6455
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6323
6456
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6324
6457
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6338,10 +6471,10 @@ kernel void kernel_get_rows_q(
|
|
|
6338
6471
|
|
|
6339
6472
|
template<typename T>
|
|
6340
6473
|
kernel void kernel_get_rows_f(
|
|
6474
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6341
6475
|
device const void * src0,
|
|
6342
6476
|
device const void * src1,
|
|
6343
6477
|
device float * dst,
|
|
6344
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6345
6478
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6346
6479
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6347
6480
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6359,10 +6492,10 @@ kernel void kernel_get_rows_f(
|
|
|
6359
6492
|
}
|
|
6360
6493
|
|
|
6361
6494
|
kernel void kernel_get_rows_i32(
|
|
6495
|
+
constant ggml_metal_kargs_get_rows & args,
|
|
6362
6496
|
device const void * src0,
|
|
6363
6497
|
device const void * src1,
|
|
6364
6498
|
device int32_t * dst,
|
|
6365
|
-
constant ggml_metal_kargs_get_rows & args,
|
|
6366
6499
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6367
6500
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
6368
6501
|
uint3 tptg [[threads_per_threadgroup]]) {
|
|
@@ -6379,6 +6512,67 @@ kernel void kernel_get_rows_i32(
|
|
|
6379
6512
|
}
|
|
6380
6513
|
}
|
|
6381
6514
|
|
|
6515
|
+
template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
|
6516
|
+
kernel void kernel_set_rows_q32(
|
|
6517
|
+
constant ggml_metal_kargs_set_rows & args,
|
|
6518
|
+
device const void * src0,
|
|
6519
|
+
device const void * src1,
|
|
6520
|
+
device float * dst,
|
|
6521
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6522
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6523
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
6524
|
+
const int32_t i03 = tgpig.z;
|
|
6525
|
+
const int32_t i02 = tgpig.y;
|
|
6526
|
+
|
|
6527
|
+
const int32_t i12 = i03%args.ne12;
|
|
6528
|
+
const int32_t i11 = i02%args.ne11;
|
|
6529
|
+
|
|
6530
|
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
|
6531
|
+
if (i01 >= args.ne01) {
|
|
6532
|
+
return;
|
|
6533
|
+
}
|
|
6534
|
+
|
|
6535
|
+
const int32_t i10 = i01;
|
|
6536
|
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
6537
|
+
|
|
6538
|
+
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
6539
|
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
6540
|
+
|
|
6541
|
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
|
6542
|
+
quantize_func(src_row + 32*ind, dst_row[ind]);
|
|
6543
|
+
}
|
|
6544
|
+
}
|
|
6545
|
+
|
|
6546
|
+
template<typename T>
|
|
6547
|
+
kernel void kernel_set_rows_f(
|
|
6548
|
+
constant ggml_metal_kargs_set_rows & args,
|
|
6549
|
+
device const void * src0,
|
|
6550
|
+
device const void * src1,
|
|
6551
|
+
device float * dst,
|
|
6552
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
6553
|
+
uint tiitg[[thread_index_in_threadgroup]],
|
|
6554
|
+
uint3 tptg [[threads_per_threadgroup]]) {
|
|
6555
|
+
const int32_t i03 = tgpig.z;
|
|
6556
|
+
const int32_t i02 = tgpig.y;
|
|
6557
|
+
|
|
6558
|
+
const int32_t i12 = i03%args.ne12;
|
|
6559
|
+
const int32_t i11 = i02%args.ne11;
|
|
6560
|
+
|
|
6561
|
+
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
|
6562
|
+
if (i01 >= args.ne01) {
|
|
6563
|
+
return;
|
|
6564
|
+
}
|
|
6565
|
+
|
|
6566
|
+
const int32_t i10 = i01;
|
|
6567
|
+
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
6568
|
+
|
|
6569
|
+
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
6570
|
+
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
6571
|
+
|
|
6572
|
+
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
|
6573
|
+
dst_row[ind] = (T) src_row[ind];
|
|
6574
|
+
}
|
|
6575
|
+
}
|
|
6382
6576
|
|
|
6383
6577
|
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
|
6384
6578
|
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
|
@@ -6802,6 +6996,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get
|
|
|
6802
6996
|
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
|
6803
6997
|
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
|
6804
6998
|
|
|
6999
|
+
//
|
|
7000
|
+
// set rows
|
|
7001
|
+
//
|
|
7002
|
+
|
|
7003
|
+
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
|
|
7004
|
+
|
|
7005
|
+
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
|
|
7006
|
+
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
|
|
7007
|
+
#if defined(GGML_METAL_USE_BF16)
|
|
7008
|
+
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
|
|
7009
|
+
#endif
|
|
7010
|
+
|
|
7011
|
+
typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
|
7012
|
+
|
|
7013
|
+
template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
|
|
7014
|
+
template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
|
|
7015
|
+
template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
|
|
7016
|
+
template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
|
|
7017
|
+
template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
|
|
7018
|
+
template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
|
|
7019
|
+
|
|
6805
7020
|
//
|
|
6806
7021
|
// matrix-matrix multiplication
|
|
6807
7022
|
//
|