@novastera-oss/llamarn 0.2.6 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/include/llama.h +134 -36
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/LlamaCppModel.cpp +2 -2
- package/cpp/LlamaCppModel.h +3 -3
- package/cpp/PureCppImpl.cpp +1 -1
- package/cpp/PureCppImpl.h +2 -2
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakeLists.txt +15 -4
- package/cpp/llama.cpp/Makefile +2 -2
- package/cpp/llama.cpp/README.md +32 -13
- package/cpp/llama.cpp/common/CMakeLists.txt +10 -20
- package/cpp/llama.cpp/common/arg.cpp +30 -6
- package/cpp/llama.cpp/common/build-info.cpp.in +2 -2
- package/cpp/llama.cpp/common/chat-parser.cpp +5 -0
- package/cpp/llama.cpp/common/chat-parser.h +2 -0
- package/cpp/llama.cpp/common/chat.cpp +12 -9
- package/cpp/llama.cpp/common/chat.h +1 -1
- package/cpp/llama.cpp/common/common.cpp +50 -40
- package/cpp/llama.cpp/common/common.h +5 -2
- package/cpp/llama.cpp/common/speculative.cpp +6 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +97 -56
- package/cpp/llama.cpp/ggml/CMakeLists.txt +47 -2
- package/cpp/llama.cpp/ggml/cmake/common.cmake +1 -2
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +47 -13
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cann/common.h +6 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
- package/cpp/llama.cpp/ggml/src/ggml-common.h +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +93 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +4113 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/arm/repack.cpp +2174 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +2638 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +2731 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +2068 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/riscv/repack.cpp +396 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +1299 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/wasm/quants.c +1480 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/x86/quants.c +4310 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +59 -3206
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/common.h +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +7 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +10 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +8 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +56 -7
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/quants.c +1157 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +1555 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.h +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +2 -4
- package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +5 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +4 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +6 -8
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +11 -10
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +33 -8
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +135 -100
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +908 -3
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
- package/cpp/llama.cpp/ggml/src/ggml-quants.c +0 -2
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
- package/cpp/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/common.hpp +19 -24
- package/cpp/llama.cpp/ggml/src/ggml-sycl/convert.cpp +21 -2
- package/cpp/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +121 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +32 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +2 -96
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +164 -38
- package/cpp/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +32 -8
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +38 -10
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +26 -29
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +431 -247
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -12
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +0 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +57 -0
- package/cpp/llama.cpp/gguf-py/gguf/gguf_writer.py +4 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +14 -3
- package/cpp/llama.cpp/include/llama.h +134 -36
- package/cpp/llama.cpp/requirements/requirements-compare-llama-bench.txt +1 -0
- package/cpp/llama.cpp/src/CMakeLists.txt +2 -2
- package/cpp/llama.cpp/src/llama-arch.cpp +95 -3
- package/cpp/llama.cpp/src/llama-arch.h +7 -1
- package/cpp/llama.cpp/src/llama-batch.cpp +270 -19
- package/cpp/llama.cpp/src/llama-batch.h +36 -11
- package/cpp/llama.cpp/src/llama-chat.cpp +19 -2
- package/cpp/llama.cpp/src/llama-chat.h +1 -0
- package/cpp/llama.cpp/src/llama-context.cpp +313 -213
- package/cpp/llama.cpp/src/llama-context.h +16 -12
- package/cpp/llama.cpp/src/llama-cparams.cpp +1 -1
- package/cpp/llama.cpp/src/llama-cparams.h +1 -1
- package/cpp/llama.cpp/src/llama-graph.cpp +249 -129
- package/cpp/llama.cpp/src/llama-graph.h +90 -34
- package/cpp/llama.cpp/src/llama-hparams.cpp +6 -2
- package/cpp/llama.cpp/src/llama-hparams.h +8 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +82 -50
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +23 -26
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +292 -174
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +68 -38
- package/cpp/llama.cpp/src/llama-kv-cells.h +18 -13
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +247 -0
- package/cpp/llama.cpp/src/llama-memory-hybrid.h +143 -0
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +266 -282
- package/cpp/llama.cpp/src/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +54 -57
- package/cpp/llama.cpp/src/llama-memory.cpp +41 -0
- package/cpp/llama.cpp/src/llama-memory.h +64 -23
- package/cpp/llama.cpp/src/llama-mmap.cpp +1 -1
- package/cpp/llama.cpp/src/llama-model-loader.cpp +42 -17
- package/cpp/llama.cpp/src/llama-model.cpp +726 -141
- package/cpp/llama.cpp/src/llama-model.h +4 -0
- package/cpp/llama.cpp/src/llama-quant.cpp +2 -1
- package/cpp/llama.cpp/src/llama-vocab.cpp +32 -23
- package/cpp/llama.cpp/src/llama.cpp +11 -7
- package/cpp/llama.cpp/src/unicode.cpp +5 -0
- package/cpp/rn-completion.cpp +2 -2
- package/cpp/{rn-llama.hpp → rn-llama.h} +1 -1
- package/ios/include/chat.h +1 -1
- package/ios/include/common.h +5 -2
- package/ios/include/llama.h +134 -36
- package/ios/libs/llama.xcframework/Info.plist +18 -18
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3744 -3624
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4863 -4689
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4834 -4710
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3742 -3622
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4900 -4725
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +4871 -4746
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +3773 -3652
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +134 -36
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +1 -2
- package/cpp/llama.cpp/common/cmake/build-info-gen-cpp.cmake +0 -24
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13891
- package/cpp/llama.cpp/src/llama-kv-cache.cpp +0 -1
- package/cpp/llama.cpp/src/llama-kv-cache.h +0 -44
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
- /package/cpp/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/{rn-utils.hpp → rn-utils.h} +0 -0
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#define GGML_COMMON_DECL_CPP
|
|
4
|
+
#include "ggml-common.h"
|
|
5
|
+
|
|
6
|
+
#include "traits.h"
|
|
7
|
+
#include "ggml.h"
|
|
8
|
+
|
|
9
|
+
// GGML internal header
|
|
10
|
+
|
|
11
|
+
ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void);
|
|
12
|
+
|
|
13
|
+
template <int K> constexpr int QK_0() {
|
|
14
|
+
if constexpr (K == 4) {
|
|
15
|
+
return QK4_0;
|
|
16
|
+
}
|
|
17
|
+
if constexpr (K == 8) {
|
|
18
|
+
return QK8_0;
|
|
19
|
+
}
|
|
20
|
+
return -1;
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
template <int K, int N> struct block {
|
|
24
|
+
ggml_half d[N]; // deltas for N qK_0 blocks
|
|
25
|
+
int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
|
|
26
|
+
};
|
|
27
|
+
|
|
28
|
+
// control size
|
|
29
|
+
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
|
|
30
|
+
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
|
|
31
|
+
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
|
|
32
|
+
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
|
|
33
|
+
|
|
34
|
+
using block_q4_0x4 = block<4, 4>;
|
|
35
|
+
using block_q4_0x8 = block<4, 8>;
|
|
36
|
+
using block_q8_0x4 = block<8, 4>;
|
|
37
|
+
using block_q8_0x8 = block<8, 8>;
|
|
38
|
+
|
|
39
|
+
struct block_q4_Kx8 {
|
|
40
|
+
ggml_half d[8]; // super-block scale for quantized scales
|
|
41
|
+
ggml_half dmin[8]; // super-block scale for quantized mins
|
|
42
|
+
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
|
43
|
+
uint8_t qs[1024]; // 4--bit quants
|
|
44
|
+
};
|
|
45
|
+
|
|
46
|
+
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
|
47
|
+
|
|
48
|
+
struct block_q8_Kx4 {
|
|
49
|
+
float d[4]; // delta
|
|
50
|
+
int8_t qs[QK_K * 4]; // quants
|
|
51
|
+
int16_t bsums[QK_K / 4]; // sum of quants in groups of 16
|
|
52
|
+
};
|
|
53
|
+
|
|
54
|
+
static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding");
|
|
55
|
+
|
|
56
|
+
struct block_iq4_nlx4 {
|
|
57
|
+
ggml_half d[4]; // deltas for 4 iq4_nl blocks
|
|
58
|
+
uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
|
|
59
|
+
};
|
|
60
|
+
|
|
61
|
+
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
|
|
62
|
+
|
|
63
|
+
#if defined(__cplusplus)
|
|
64
|
+
extern "C" {
|
|
65
|
+
#endif
|
|
66
|
+
|
|
67
|
+
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
68
|
+
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
69
|
+
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
70
|
+
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
71
|
+
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
72
|
+
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
73
|
+
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
74
|
+
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
75
|
+
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
76
|
+
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
77
|
+
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
78
|
+
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
79
|
+
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
80
|
+
|
|
81
|
+
// Native implementations
|
|
82
|
+
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
83
|
+
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
84
|
+
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
|
85
|
+
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
86
|
+
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
87
|
+
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
88
|
+
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
89
|
+
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
90
|
+
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
91
|
+
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
92
|
+
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
93
|
+
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
94
|
+
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
|
95
|
+
|
|
96
|
+
#if defined(__cplusplus)
|
|
97
|
+
} // extern "C"
|
|
98
|
+
#endif
|
|
@@ -944,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
|
|
944
944
|
for (int i = 0; i < offset; ++i) { \
|
|
945
945
|
x[i] = vec_add(x[i], x[offset + i]); \
|
|
946
946
|
} \
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
vec_extract(x[0], 2) + \
|
|
950
|
-
vec_extract(x[0], 3); \
|
|
947
|
+
float32x4_t tmp = x[0] + vec_reve(x[0]); \
|
|
948
|
+
res = tmp[0] + tmp[1]; \
|
|
951
949
|
}
|
|
952
950
|
|
|
953
951
|
#define GGML_F32_VEC GGML_F32x4
|
|
@@ -207,9 +207,9 @@ typedef float2 dfloat2;
|
|
|
207
207
|
#define FP16_MMA_AVAILABLE
|
|
208
208
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
|
209
209
|
|
|
210
|
-
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
|
|
210
|
+
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
|
211
211
|
#define FP16_MMA_AVAILABLE
|
|
212
|
-
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
|
|
212
|
+
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
|
213
213
|
|
|
214
214
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
|
215
215
|
#define NEW_MMA_AVAILABLE
|
|
@@ -262,11 +262,11 @@ static bool cp_async_available(const int cc) {
|
|
|
262
262
|
}
|
|
263
263
|
|
|
264
264
|
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
|
265
|
-
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
|
266
|
-
return
|
|
265
|
+
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
|
|
266
|
+
return 64;
|
|
267
267
|
#else
|
|
268
268
|
return 32;
|
|
269
|
-
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
|
269
|
+
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
|
|
270
270
|
}
|
|
271
271
|
|
|
272
272
|
[[noreturn]]
|
|
@@ -466,9 +466,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
|
|
|
466
466
|
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
|
467
467
|
}
|
|
468
468
|
|
|
469
|
-
// TODO: move to ggml-common.h
|
|
470
|
-
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
|
471
|
-
|
|
472
469
|
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
|
|
473
470
|
|
|
474
471
|
static __device__ __forceinline__ float get_alibi_slope(
|
|
@@ -652,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
|
652
652
|
float KQ_max_scale[cols_per_thread];
|
|
653
653
|
#pragma unroll
|
|
654
654
|
for (int col = 0; col < cols_per_thread; ++col) {
|
|
655
|
-
|
|
655
|
+
const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
|
|
656
|
+
KQ_max_scale[col] = expf(KQ_max_diff);
|
|
656
657
|
KQ_max[col] = KQ_max_new[col];
|
|
657
658
|
|
|
659
|
+
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
|
660
|
+
|
|
658
661
|
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
|
659
662
|
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
|
|
660
663
|
}
|
|
@@ -615,9 +615,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
|
|
|
615
615
|
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
|
|
616
616
|
|
|
617
617
|
ggml_cuda_set_device(ctx->device);
|
|
618
|
-
CUDA_CHECK(
|
|
619
|
-
CUDA_CHECK(
|
|
620
|
-
CUDA_CHECK(cudaDeviceSynchronize());
|
|
618
|
+
CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
|
|
619
|
+
CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
|
|
621
620
|
}
|
|
622
621
|
|
|
623
622
|
static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
|
|
@@ -1144,7 +1143,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)(
|
|
|
1144
1143
|
static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|
1145
1144
|
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
|
|
1146
1145
|
|
|
1147
|
-
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
|
|
1148
1146
|
const char * src_ptr = (const char *) src->data;
|
|
1149
1147
|
char * dst_ptr = (char *) dst;
|
|
1150
1148
|
|
|
@@ -1427,8 +1425,6 @@ static void ggml_cuda_op_mul_mat(
|
|
|
1427
1425
|
const int64_t nb2 = dst->nb[2];
|
|
1428
1426
|
const int64_t nb3 = dst->nb[3];
|
|
1429
1427
|
|
|
1430
|
-
GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
|
|
1431
|
-
GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
|
|
1432
1428
|
ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
|
|
1433
1429
|
ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
|
|
1434
1430
|
|
|
@@ -1750,7 +1746,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|
|
1750
1746
|
GGML_ASSERT(!ggml_is_transposed(src0));
|
|
1751
1747
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
|
1752
1748
|
|
|
1753
|
-
GGML_ASSERT(
|
|
1749
|
+
GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
|
|
1754
1750
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
1755
1751
|
|
|
1756
1752
|
// Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
|
|
@@ -2668,7 +2664,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|
|
2668
2664
|
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
|
|
2669
2665
|
}
|
|
2670
2666
|
}
|
|
2671
|
-
#
|
|
2667
|
+
#else
|
|
2668
|
+
GGML_UNUSED(integrated);
|
|
2669
|
+
#endif // NDEBUG
|
|
2672
2670
|
|
|
2673
2671
|
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
|
2674
2672
|
if (!ok) {
|
|
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
|
10
10
|
float * __restrict__ dst, const int64_t L) {
|
|
11
11
|
GGML_UNUSED(src1_nb0);
|
|
12
12
|
GGML_UNUSED(src2_nb0);
|
|
13
|
+
|
|
14
|
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
13
15
|
const int bidx = blockIdx.x; // split along B
|
|
14
16
|
const int bidy = blockIdx.y; // split along D
|
|
15
17
|
const int tid = threadIdx.x;
|
|
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
|
|
|
44
46
|
if (N == 16) {
|
|
45
47
|
#pragma unroll
|
|
46
48
|
for (size_t i = 0; i < splitD / 4; i += 2) {
|
|
47
|
-
float value = A_block[(wid *
|
|
49
|
+
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
|
|
48
50
|
// todo: bank conflict
|
|
49
51
|
// I am always confused with how to use the swizzling method to solve
|
|
50
52
|
// bank conflit. Hoping somebody can tell me.
|
|
51
|
-
smem_A[(wid *
|
|
53
|
+
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
|
52
54
|
}
|
|
53
55
|
#pragma unroll
|
|
54
56
|
for (size_t i = 0; i < splitD / 4; i += 2) {
|
|
55
|
-
float value = s0_block[(wid *
|
|
56
|
-
smem_s0[(wid *
|
|
57
|
+
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
|
|
58
|
+
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
|
57
59
|
}
|
|
58
60
|
}
|
|
59
61
|
|
|
@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
|
|
|
113
113
|
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
|
|
114
114
|
endif()
|
|
115
115
|
|
|
116
|
+
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
|
|
117
|
+
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
|
118
|
+
endif()
|
|
119
|
+
|
|
116
120
|
if (NOT GGML_CUDA_FA)
|
|
117
121
|
add_compile_definitions(GGML_CUDA_NO_FA)
|
|
118
122
|
endif()
|
|
@@ -44,21 +44,22 @@ if (GGML_METAL_EMBED_LIBRARY)
|
|
|
44
44
|
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
|
|
45
45
|
|
|
46
46
|
add_custom_command(
|
|
47
|
-
OUTPUT ${METALLIB_EMBED_ASM}
|
|
47
|
+
OUTPUT "${METALLIB_EMBED_ASM}"
|
|
48
48
|
COMMAND echo "Embedding Metal library"
|
|
49
|
-
COMMAND sed -e
|
|
50
|
-
COMMAND sed -e
|
|
51
|
-
COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM}
|
|
52
|
-
COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM}
|
|
53
|
-
COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM}
|
|
54
|
-
COMMAND echo
|
|
55
|
-
COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM}
|
|
56
|
-
COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM}
|
|
49
|
+
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
|
|
50
|
+
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
|
|
51
|
+
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
|
|
52
|
+
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
|
|
53
|
+
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
|
|
54
|
+
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
|
|
55
|
+
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
|
|
56
|
+
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
|
|
57
57
|
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
|
|
58
58
|
COMMENT "Generate assembly for embedded Metal library"
|
|
59
|
+
VERBATIM
|
|
59
60
|
)
|
|
60
61
|
|
|
61
|
-
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM})
|
|
62
|
+
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
|
|
62
63
|
else()
|
|
63
64
|
if (GGML_METAL_SHADER_DEBUG)
|
|
64
65
|
# custom command to do the following:
|
|
@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
|
|
|
498
498
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
499
499
|
GGML_METAL_KERNEL_TYPE_NEG,
|
|
500
500
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
501
|
+
GGML_METAL_KERNEL_TYPE_MEAN,
|
|
501
502
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
502
503
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
|
503
504
|
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
1454
1455
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
1455
1456
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
|
1456
1457
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
1458
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
|
1457
1459
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
1458
1460
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
1459
1461
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
1653
1655
|
case GGML_OP_LOG:
|
|
1654
1656
|
return false; // TODO: implement
|
|
1655
1657
|
case GGML_OP_SUM_ROWS:
|
|
1658
|
+
case GGML_OP_MEAN:
|
|
1656
1659
|
case GGML_OP_SOFT_MAX:
|
|
1657
1660
|
case GGML_OP_GROUP_NORM:
|
|
1658
1661
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
|
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
|
|
|
2400
2403
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
2401
2404
|
} break;
|
|
2402
2405
|
case GGML_OP_SUM_ROWS:
|
|
2406
|
+
case GGML_OP_MEAN:
|
|
2403
2407
|
{
|
|
2404
2408
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
|
2405
2409
|
|
|
2406
|
-
id<MTLComputePipelineState> pipeline =
|
|
2410
|
+
id<MTLComputePipelineState> pipeline = nil;
|
|
2411
|
+
|
|
2412
|
+
switch (dst->op) {
|
|
2413
|
+
case GGML_OP_SUM_ROWS:
|
|
2414
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
|
2415
|
+
break;
|
|
2416
|
+
case GGML_OP_MEAN:
|
|
2417
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
|
2418
|
+
break;
|
|
2419
|
+
default:
|
|
2420
|
+
GGML_ABORT("fatal error");
|
|
2421
|
+
}
|
|
2422
|
+
|
|
2423
|
+
int nth = 32; // SIMD width
|
|
2424
|
+
|
|
2425
|
+
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
|
2426
|
+
nth *= 2;
|
|
2427
|
+
}
|
|
2407
2428
|
|
|
2429
|
+
nth = MIN(nth, ne00);
|
|
2408
2430
|
|
|
2409
2431
|
ggml_metal_kargs_sum_rows args = {
|
|
2410
2432
|
/*.ne00 =*/ ne00,
|
|
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
|
|
|
2434
2456
|
};
|
|
2435
2457
|
|
|
2436
2458
|
[encoder setComputePipelineState:pipeline];
|
|
2437
|
-
[encoder
|
|
2438
|
-
[encoder setBuffer:
|
|
2439
|
-
[encoder
|
|
2459
|
+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
|
2460
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
|
2461
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
2462
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
2440
2463
|
|
|
2441
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(
|
|
2464
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2442
2465
|
} break;
|
|
2443
2466
|
case GGML_OP_SOFT_MAX:
|
|
2444
2467
|
{
|
|
@@ -4766,6 +4789,8 @@ static bool ggml_metal_encode_node(
|
|
|
4766
4789
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
4767
4790
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
4768
4791
|
|
|
4792
|
+
const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
|
|
4793
|
+
|
|
4769
4794
|
// 2*(2*ncpsg + nqptg)*(nsg)
|
|
4770
4795
|
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
|
4771
4796
|
//
|
|
@@ -4773,7 +4798,7 @@ static bool ggml_metal_encode_node(
|
|
|
4773
4798
|
// the shared memory needed for the simdgroups to load the KV cache
|
|
4774
4799
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
|
4775
4800
|
//
|
|
4776
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
|
4801
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
|
|
4777
4802
|
|
|
4778
4803
|
int64_t nsgmax = 2;
|
|
4779
4804
|
|
|
@@ -4810,9 +4835,9 @@ static bool ggml_metal_encode_node(
|
|
|
4810
4835
|
// and store the soft_max values and the mask
|
|
4811
4836
|
//
|
|
4812
4837
|
// ne00*(nsg)
|
|
4813
|
-
// each simdgroup has a full
|
|
4838
|
+
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
|
4814
4839
|
//
|
|
4815
|
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
|
|
4840
|
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
|
|
4816
4841
|
|
|
4817
4842
|
int64_t nsgmax = 2;
|
|
4818
4843
|
while (true) {
|