@novastera-oss/llamarn 0.2.5 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/RNLlamaCpp.podspec +3 -2
- package/android/CMakeLists.txt +6 -3
- package/android/src/main/cpp/include/llama.h +140 -38
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +48 -67
- package/cpp/LlamaCppModel.h +8 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +33 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +15 -28
- package/cpp/llama.cpp/common/arg.cpp +38 -12
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +9 -3
- package/cpp/llama.cpp/common/chat-parser.h +4 -1
- package/cpp/llama.cpp/common/chat.cpp +16 -13
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +52 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/json-partial.cpp +5 -4
- package/cpp/llama.cpp/common/json-partial.h +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.cpp +2 -1
- package/cpp/llama.cpp/common/json-schema-to-grammar.h +4 -4
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +128 -84
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/include/ggml.h +1 -3
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +49 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +10 -5
- package/cpp/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +33 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +25 -16
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/rope.cpp +118 -11
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +432 -248
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +9 -8
- package/cpp/llama.cpp/ggml/src/ggml.cpp +26 -0
- package/cpp/llama.cpp/ggml/src/gguf.cpp +19 -2
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +140 -38
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +4 -1
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +289 -31
- package/cpp/llama.cpp/src/llama-batch.h +47 -17
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +488 -313
- package/cpp/llama.cpp/src/llama-context.h +38 -17
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +275 -152
- package/cpp/llama.cpp/src/llama-graph.h +109 -52
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +281 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +133 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +1835 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +308 -0
- package/cpp/llama.cpp/src/llama-kv-cells.h +53 -17
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +1116 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.h +188 -0
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +89 -4
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +735 -143
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +39 -25
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/llama.cpp/vendor/cpp-httplib/httplib.h +10518 -0
- package/cpp/llama.cpp/vendor/miniaudio/miniaudio.h +93468 -0
- package/cpp/llama.cpp/{common → vendor}/minja/chat-template.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor}/minja/minja.hpp +1 -1
- package/cpp/llama.cpp/{common → vendor/nlohmann}/json.hpp +3027 -2267
- package/cpp/llama.cpp/vendor/nlohmann/json_fwd.hpp +187 -0
- package/cpp/llama.cpp/vendor/stb/stb_image.h +7988 -0
- package/cpp/rn-completion.cpp +65 -10
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/cpp/{rn-utils.hpp → rn-utils.h} +8 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common/minja/chat-template.hpp +1 -1
- package/ios/include/common/minja/minja.hpp +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/json-schema-to-grammar.h +4 -4
- package/ios/include/llama.h +140 -38
- package/ios/include/{common → nlohmann}/json.hpp +3027 -2267
- package/ios/libs/llama.xcframework/Info.plist +20 -20
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4617
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3557
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4638
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3559
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4616
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4637
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3556
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4653
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4674
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3587
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/ggml.h +1 -3
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +140 -38
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -2747
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -502
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#define GGML_COMMON_DECL_CPP
|
|
4
|
+
#include "ggml-common.h"
|
|
5
|
+
|
|
6
|
+
#include "traits.h"
|
|
7
|
+
#include "ggml.h"
|
|
8
|
+
|
|
9
|
+
// GGML internal header
|
|
10
|
+
|
|
11
|
+
ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);
|
|
12
|
+
|
|
13
|
+
template <int K> constexpr int QK_0() {
|
|
14
|
+
if constexpr (K == 4) {
|
|
15
|
+
return QK4_0;
|
|
16
|
+
}
|
|
17
|
+
if constexpr (K == 8) {
|
|
18
|
+
return QK8_0;
|
|
19
|
+
}
|
|
20
|
+
return -1;
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
template <int K, int N> struct block {
|
|
24
|
+
ggml_half d[N]; // deltas for N qK_0 blocks
|
|
25
|
+
int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
|
|
26
|
+
};
|
|
27
|
+
|
|
28
|
+
// control size
|
|
29
|
+
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
|
|
30
|
+
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
|
|
31
|
+
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
|
|
32
|
+
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
|
|
33
|
+
|
|
34
|
+
using block_q4_0x4 = block<4, 4>;
|
|
35
|
+
using block_q4_0x8 = block<4, 8>;
|
|
36
|
+
using block_q8_0x4 = block<8, 4>;
|
|
37
|
+
using block_q8_0x8 = block<8, 8>;
|
|
38
|
+
|
|
39
|
+
struct block_q4_Kx8 {
|
|
40
|
+
ggml_half d[8]; // super-block scale for quantized scales
|
|
41
|
+
ggml_half dmin[8]; // super-block scale for quantized mins
|
|
42
|
+
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
|
43
|
+
uint8_t qs[1024]; // 4--bit quants
|
|
44
|
+
};
|
|
45
|
+
|
|
46
|
+
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
|
47
|
+
|
|
48
|
+
struct block_q8_Kx4 {
|
|
49
|
+
float d[4]; // delta
|
|
50
|
+
int8_t qs[QK_K * 4]; // quants
|
|
51
|
+
int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
|
|
52
|
+
};
|
|
53
|
+
|
|
54
|
+
static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
|
|
55
|
+
|
|
56
|
+
struct block_iq4_nlx4 {
|
|
57
|
+
ggml_half d[4]; // deltas for 4 iq4_nl blocks
|
|
58
|
+
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
|
|
59
|
+
};
|
|
60
|
+
|
|
61
|
+
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
|
|
62
|
+
|
|
63
|
+
#if defined(__cplusplus)
|
|
64
|
+
extern "C" {
|
|
65
|
+
#endif
|
|
66
|
+
|
|
67
|
+
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
68
|
+
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
69
|
+
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
70
|
+
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
71
|
+
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
72
|
+
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
73
|
+
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
74
|
+
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
75
|
+
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
76
|
+
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
77
|
+
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
78
|
+
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
79
|
+
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
80
|
+
|
|
81
|
+
// Native implementations
|
|
82
|
+
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
83
|
+
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
84
|
+
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
85
|
+
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
86
|
+
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
87
|
+
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
88
|
+
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
89
|
+
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
90
|
+
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
91
|
+
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
92
|
+
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
93
|
+
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
94
|
+
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
95
|
+
|
|
96
|
+
#if defined(__cplusplus)
|
|
97
|
+
} // extern "C"
|
|
98
|
+
#endif
|
|
@@ -944,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
|
|
944
944
|
for (int i = 0; i < offset; ++i) { \
|
|
945
945
|
x[i] = vec_add(x[i], x[offset + i]); \
|
|
946
946
|
} \
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
vec_extract(x[0], 2) + \
|
|
950
|
-
vec_extract(x[0], 3); \
|
|
947
|
+
float32x4_t tmp = x[0] + vec_reve(x[0]); \
|
|
948
|
+
res = tmp[0] + tmp[1]; \
|
|
951
949
|
}
|
|
952
950
|
|
|
953
951
|
#define GGML_F32_VEC GGML_F32x4
|
|
@@ -207,9 +207,9 @@ typedef float2 dfloat2;
|
|
|
207
207
|
#define FP16_MMA_AVAILABLE
|
|
208
208
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
|
209
209
|
|
|
210
|
-
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
|
|
210
|
+
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
|
211
211
|
#define FP16_MMA_AVAILABLE
|
|
212
|
-
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
|
|
212
|
+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
|
213
213
|
|
|
214
214
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
|
215
215
|
#define NEW_MMA_AVAILABLE
|
|
@@ -262,11 +262,11 @@ static bool cp_async_available(const int cc) {
|
|
|
262
262
|
}
|
|
263
263
|
|
|
264
264
|
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
|
265
|
-
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
|
266
|
-
return
|
|
265
|
+
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
|
|
266
|
+
return 64;
|
|
267
267
|
#else
|
|
268
268
|
return 32;
|
|
269
|
-
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
|
269
|
+
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
|
|
270
270
|
}
|
|
271
271
|
|
|
272
272
|
[[noreturn]]
|
|
@@ -466,9 +466,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
|
|
466
466
|
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
|
467
467
|
}
|
|
468
468
|
|
|
469
|
-
// TODO: move to ggml-common.h
|
|
470
|
-
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
|
471
|
-
|
|
472
469
|
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
|
473
470
|
|
|
474
471
|
static __device__ __forceinline__ float get_alibi_slope(
|
|
@@ -635,6 +632,7 @@ struct ggml_cuda_device_info {
|
|
|
635
632
|
int nsm; // number of streaming multiprocessors
|
|
636
633
|
size_t smpb; // max. shared memory per block
|
|
637
634
|
size_t smpbo; // max. shared memory per block (with opt-in)
|
|
635
|
+
bool integrated; // Device is integrated as opposed to discrete
|
|
638
636
|
bool vmm; // virtual memory support
|
|
639
637
|
size_t vmm_granularity; // granularity of virtual memory
|
|
640
638
|
size_t total_vram;
|
|
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
652
652
|
float KQ_max_scale[cols_per_thread];
|
|
653
653
|
#pragma unroll
|
|
654
654
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
655
|
-
|
|
655
|
+
const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
|
|
656
|
+
KQ_max_scale[col] = expf(KQ_max_diff);
|
|
656
657
|
KQ_max[col] = KQ_max_new[col];
|
|
657
658
|
|
|
659
|
+
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
|
660
|
+
|
|
658
661
|
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
|
659
662
|
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
|
|
660
663
|
}
|
|
@@ -1246,7 +1249,7 @@ static __global__ void flash_attn_ext_f16(
|
|
|
1246
1249
|
NO_DEVICE_CODE;
|
|
1247
1250
|
return;
|
|
1248
1251
|
}
|
|
1249
|
-
#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1252
|
+
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
|
1250
1253
|
|
|
1251
1254
|
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
|
1252
1255
|
|
|
@@ -243,10 +243,10 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|
|
243
243
|
|
|
244
244
|
info.default_tensor_split[id] = total_vram;
|
|
245
245
|
total_vram += prop.totalGlobalMem;
|
|
246
|
-
|
|
247
|
-
info.devices[id].nsm
|
|
248
|
-
info.devices[id].smpb
|
|
249
|
-
info.devices[id].warp_size
|
|
246
|
+
info.devices[id].integrated = prop.integrated;
|
|
247
|
+
info.devices[id].nsm = prop.multiProcessorCount;
|
|
248
|
+
info.devices[id].smpb = prop.sharedMemPerBlock;
|
|
249
|
+
info.devices[id].warp_size = prop.warpSize;
|
|
250
250
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
|
251
251
|
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
|
252
252
|
|
|
@@ -615,9 +615,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
|
|
|
615
615
|
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
|
|
616
616
|
|
|
617
617
|
ggml_cuda_set_device(ctx->device);
|
|
618
|
-
CUDA_CHECK(
|
|
619
|
-
CUDA_CHECK(
|
|
620
|
-
CUDA_CHECK(cudaDeviceSynchronize());
|
|
618
|
+
CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
|
|
619
|
+
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
|
|
621
620
|
}
|
|
622
621
|
|
|
623
622
|
static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
|
|
@@ -1065,6 +1064,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
|
|
|
1065
1064
|
GGML_UNUSED(buft);
|
|
1066
1065
|
}
|
|
1067
1066
|
|
|
1067
|
+
static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
|
|
1068
|
+
return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
|
|
1069
|
+
}
|
|
1070
|
+
|
|
1068
1071
|
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
1069
1072
|
CUDA_CHECK(cudaFreeHost(buffer->context));
|
|
1070
1073
|
}
|
|
@@ -1140,7 +1143,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
|
|
|
1140
1143
|
static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|
1141
1144
|
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
|
|
1142
1145
|
|
|
1143
|
-
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
|
|
1144
1146
|
const char * src_ptr = (const char *) src->data;
|
|
1145
1147
|
char * dst_ptr = (char *) dst;
|
|
1146
1148
|
|
|
@@ -1423,8 +1425,6 @@ static void ggml_cuda_op_mul_mat(
|
|
|
1423
1425
|
const int64_t nb2 = dst->nb[2];
|
|
1424
1426
|
const int64_t nb3 = dst->nb[3];
|
|
1425
1427
|
|
|
1426
|
-
GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
|
|
1427
|
-
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
|
|
1428
1428
|
ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
|
|
1429
1429
|
ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
|
|
1430
1430
|
|
|
@@ -1746,7 +1746,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
|
1746
1746
|
GGML_ASSERT(!ggml_is_transposed(src0));
|
|
1747
1747
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
|
1748
1748
|
|
|
1749
|
-
GGML_ASSERT(
|
|
1749
|
+
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
|
|
1750
1750
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
1751
1751
|
|
|
1752
1752
|
// Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
|
|
@@ -2641,6 +2641,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
|
|
2641
2641
|
|
|
2642
2642
|
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
|
2643
2643
|
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
|
|
2644
|
+
// flag used to determine whether it is an integrated_gpu
|
|
2645
|
+
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
|
|
2644
2646
|
|
|
2645
2647
|
while (!graph_evaluated_or_captured) {
|
|
2646
2648
|
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
|
@@ -2659,10 +2661,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|
|
2659
2661
|
if (node->src[j] != nullptr) {
|
|
2660
2662
|
assert(node->src[j]->buffer);
|
|
2661
2663
|
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
|
|
2662
|
-
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
|
|
2664
|
+
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
|
|
2663
2665
|
}
|
|
2664
2666
|
}
|
|
2665
|
-
#
|
|
2667
|
+
#else
|
|
2668
|
+
GGML_UNUSED(integrated);
|
|
2669
|
+
#endif // NDEBUG
|
|
2666
2670
|
|
|
2667
2671
|
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
|
2668
2672
|
if (!ok) {
|
|
@@ -2994,9 +2998,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
2994
2998
|
{
|
|
2995
2999
|
struct ggml_tensor * a = op->src[0];
|
|
2996
3000
|
struct ggml_tensor * b = op->src[1];
|
|
2997
|
-
// for small weight matrices the active device can end up without any rows, don't use row split in those cases
|
|
2998
|
-
// this avoids some edge cases (and the performance would not be good anyways)
|
|
2999
3001
|
if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {
|
|
3002
|
+
if (a->ne[2] > 1 || a->ne[3] > 1) {
|
|
3003
|
+
return false;
|
|
3004
|
+
}
|
|
3005
|
+
// for small weight matrices the active device can end up without any rows, don't use row split in those cases
|
|
3006
|
+
// this avoids some edge cases (and the performance would not be good anyways)
|
|
3000
3007
|
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;
|
|
3001
3008
|
int64_t row_low;
|
|
3002
3009
|
int64_t row_high;
|
|
@@ -3263,7 +3270,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
|
3263
3270
|
}
|
|
3264
3271
|
|
|
3265
3272
|
static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
|
3266
|
-
|
|
3273
|
+
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
|
|
3274
|
+
const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;
|
|
3275
|
+
return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));
|
|
3267
3276
|
}
|
|
3268
3277
|
|
|
3269
3278
|
static int64_t get_op_batch_size(const ggml_tensor * op) {
|
|
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
|
10
10
|
float * __restrict__ dst, const int64_t L) {
|
|
11
11
|
GGML_UNUSED(src1_nb0);
|
|
12
12
|
GGML_UNUSED(src2_nb0);
|
|
13
|
+
|
|
14
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
13
15
|
const int bidx = blockIdx.x; // split along B
|
|
14
16
|
const int bidy = blockIdx.y; // split along D
|
|
15
17
|
const int tid = threadIdx.x;
|
|
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
|
44
46
|
if (N == 16) {
|
|
45
47
|
#pragma unroll
|
|
46
48
|
for (size_t i = 0; i < splitD / 4; i += 2) {
|
|
47
|
-
float value = A_block[(wid *
|
|
49
|
+
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
|
|
48
50
|
// todo: bank conflict
|
|
49
51
|
// I am always confused with how to use the swizzling method to solve
|
|
50
52
|
// bank conflit. Hoping somebody can tell me.
|
|
51
|
-
smem_A[(wid *
|
|
53
|
+
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
|
52
54
|
}
|
|
53
55
|
#pragma unroll
|
|
54
56
|
for (size_t i = 0; i < splitD / 4; i += 2) {
|
|
55
|
-
float value = s0_block[(wid *
|
|
56
|
-
smem_s0[(wid *
|
|
57
|
+
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
|
|
58
|
+
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
|
57
59
|
}
|
|
58
60
|
}
|
|
59
61
|
|
|
@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
|
|
|
113
113
|
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
|
|
114
114
|
endif()
|
|
115
115
|
|
|
116
|
+
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
|
|
117
|
+
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
|
118
|
+
endif()
|
|
119
|
+
|
|
116
120
|
if (NOT GGML_CUDA_FA)
|
|
117
121
|
add_compile_definitions(GGML_CUDA_NO_FA)
|
|
118
122
|
endif()
|
|
@@ -44,21 +44,22 @@ if (GGML_METAL_EMBED_LIBRARY)
|
|
|
44
44
|
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
|
|
45
45
|
|
|
46
46
|
add_custom_command(
|
|
47
|
-
OUTPUT ${METALLIB_EMBED_ASM}
|
|
47
|
+
OUTPUT "${METALLIB_EMBED_ASM}"
|
|
48
48
|
COMMAND echo "Embedding Metal library"
|
|
49
|
-
COMMAND sed -e
|
|
50
|
-
COMMAND sed -e
|
|
51
|
-
COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM}
|
|
52
|
-
COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM}
|
|
53
|
-
COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM}
|
|
54
|
-
COMMAND echo
|
|
55
|
-
COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM}
|
|
56
|
-
COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM}
|
|
49
|
+
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
|
|
50
|
+
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
|
|
51
|
+
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
|
|
52
|
+
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
|
|
53
|
+
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
|
|
54
|
+
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
|
|
55
|
+
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
|
|
56
|
+
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
|
|
57
57
|
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
|
|
58
58
|
COMMENT "Generate assembly for embedded Metal library"
|
|
59
|
+
VERBATIM
|
|
59
60
|
)
|
|
60
61
|
|
|
61
|
-
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM})
|
|
62
|
+
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
|
|
62
63
|
else()
|
|
63
64
|
if (GGML_METAL_SHADER_DEBUG)
|
|
64
65
|
# custom command to do the following:
|
|
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
|
|
|
498
498
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
499
499
|
GGML_METAL_KERNEL_TYPE_NEG,
|
|
500
500
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
501
|
+
GGML_METAL_KERNEL_TYPE_MEAN,
|
|
501
502
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
502
503
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
|
503
504
|
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1454
1455
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
1455
1456
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
1456
1457
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1458
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1457
1459
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
1458
1460
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
1459
1461
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1653
1655
|
case GGML_OP_LOG:
|
|
1654
1656
|
return false; // TODO: implement
|
|
1655
1657
|
case GGML_OP_SUM_ROWS:
|
|
1658
|
+
case GGML_OP_MEAN:
|
|
1656
1659
|
case GGML_OP_SOFT_MAX:
|
|
1657
1660
|
case GGML_OP_GROUP_NORM:
|
|
1658
1661
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
|
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
|
|
|
2400
2403
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2401
2404
|
} break;
|
|
2402
2405
|
case GGML_OP_SUM_ROWS:
|
|
2406
|
+
case GGML_OP_MEAN:
|
|
2403
2407
|
{
|
|
2404
2408
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
|
2405
2409
|
|
|
2406
|
-
id<MTLComputePipelineState> pipeline =
|
|
2410
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2411
|
+
|
|
2412
|
+
switch (dst->op) {
|
|
2413
|
+
case GGML_OP_SUM_ROWS:
|
|
2414
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
2415
|
+
break;
|
|
2416
|
+
case GGML_OP_MEAN:
|
|
2417
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
|
2418
|
+
break;
|
|
2419
|
+
default:
|
|
2420
|
+
GGML_ABORT("fatal error");
|
|
2421
|
+
}
|
|
2422
|
+
|
|
2423
|
+
int nth = 32; // SIMD width
|
|
2424
|
+
|
|
2425
|
+
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
2426
|
+
nth *= 2;
|
|
2427
|
+
}
|
|
2407
2428
|
|
|
2429
|
+
nth = MIN(nth, ne00);
|
|
2408
2430
|
|
|
2409
2431
|
ggml_metal_kargs_sum_rows args = {
|
|
2410
2432
|
/*.ne00 =*/ ne00,
|
|
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
|
|
|
2434
2456
|
};
|
|
2435
2457
|
|
|
2436
2458
|
[encoder setComputePipelineState:pipeline];
|
|
2437
|
-
[encoder
|
|
2438
|
-
[encoder setBuffer:
|
|
2439
|
-
[encoder
|
|
2459
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2460
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2461
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2462
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2440
2463
|
|
|
2441
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
|
2464
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2442
2465
|
} break;
|
|
2443
2466
|
case GGML_OP_SOFT_MAX:
|
|
2444
2467
|
{
|
|
@@ -4766,6 +4789,8 @@ static bool ggml_metal_encode_node(
|
|
|
4766
4789
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
4767
4790
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
4768
4791
|
|
|
4792
|
+
const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
|
|
4793
|
+
|
|
4769
4794
|
// 2*(2*ncpsg + nqptg)*(nsg)
|
|
4770
4795
|
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
|
4771
4796
|
//
|
|
@@ -4773,7 +4798,7 @@ static bool ggml_metal_encode_node(
|
|
|
4773
4798
|
// the shared memory needed for the simdgroups to load the KV cache
|
|
4774
4799
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
|
4775
4800
|
//
|
|
4776
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
|
4801
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
|
|
4777
4802
|
|
|
4778
4803
|
int64_t nsgmax = 2;
|
|
4779
4804
|
|
|
@@ -4810,9 +4835,9 @@ static bool ggml_metal_encode_node(
|
|
|
4810
4835
|
// and store the soft_max values and the mask
|
|
4811
4836
|
//
|
|
4812
4837
|
// ne00*(nsg)
|
|
4813
|
-
// each simdgroup has a full
|
|
4838
|
+
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
|
4814
4839
|
//
|
|
4815
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
|
|
4840
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
|
|
4816
4841
|
|
|
4817
4842
|
int64_t nsgmax = 2;
|
|
4818
4843
|
while (true) {
|