@novastera-oss/llamarn 0.3.0 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/build.gradle +2 -1
- package/android/proguard-rules.pro +12 -0
- package/android/src/main/cpp/include/llama.h +15 -47
- package/android/src/main/jniLibs/arm64-v8a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libggml.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/libllama.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-base.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libggml.so +0 -0
- package/android/src/main/jniLibs/armeabi-v7a/libllama.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86/libggml.so +0 -0
- package/android/src/main/jniLibs/x86/libllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-base.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml-cpu.so +0 -0
- package/android/src/main/jniLibs/x86_64/libggml.so +0 -0
- package/android/src/main/jniLibs/x86_64/libllama.so +0 -0
- package/cpp/build-info.cpp +2 -2
- package/cpp/llama.cpp/CMakePresets.json +11 -0
- package/cpp/llama.cpp/CODEOWNERS +1 -0
- package/cpp/llama.cpp/README.md +4 -3
- package/cpp/llama.cpp/common/arg.cpp +45 -1
- package/cpp/llama.cpp/common/common.cpp +22 -6
- package/cpp/llama.cpp/common/common.h +18 -4
- package/cpp/llama.cpp/convert_hf_to_gguf.py +500 -32
- package/cpp/llama.cpp/convert_hf_to_gguf_update.py +12 -13
- package/cpp/llama.cpp/ggml/CMakeLists.txt +6 -1
- package/cpp/llama.cpp/ggml/cmake/ggml-config.cmake.in +85 -47
- package/cpp/llama.cpp/ggml/include/ggml-webgpu.h +19 -0
- package/cpp/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/cpp/llama.cpp/ggml/src/ggml-alloc.c +0 -15
- package/cpp/llama.cpp/ggml/src/ggml-backend-reg.cpp +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-backend.cpp +8 -20
- package/cpp/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +58 -3
- package/cpp/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +130 -22
- package/cpp/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +122 -16
- package/cpp/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +5 -2
- package/cpp/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +109 -12
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +88 -10
- package/cpp/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +343 -1094
- package/cpp/llama.cpp/ggml/src/ggml-cpu/ops.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cpu/repack.cpp +0 -1
- package/cpp/llama.cpp/ggml/src/ggml-cpu/vec.cpp +3 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +3 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/common.cuh +14 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/convert.cu +64 -17
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh +225 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/cpy.cu +41 -301
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh +85 -67
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh +45 -62
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f16.cu +28 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-tile-f32.cu +41 -56
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f16.cuh +36 -47
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-vec-f32.cuh +31 -43
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu +22 -37
- package/cpp/llama.cpp/ggml/src/ggml-cuda/fattn.cu +3 -13
- package/cpp/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu +73 -23
- package/cpp/llama.cpp/ggml/src/ggml-cuda/im2col.cu +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mma.cuh +111 -3
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cu +6 -4
- package/cpp/llama.cpp/ggml/src/ggml-cuda/mmq.cuh +1152 -689
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cu +92 -5
- package/cpp/llama.cpp/ggml/src/ggml-cuda/norm.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cu +275 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cu +7 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/unary.cuh +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +13 -1
- package/cpp/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +2 -2
- package/cpp/llama.cpp/ggml/src/ggml-impl.h +16 -0
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +13 -3
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.m +407 -69
- package/cpp/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal +380 -83
- package/cpp/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +18 -4
- package/cpp/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +295 -2
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
- package/cpp/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +4 -4
- package/cpp/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +14 -26
- package/cpp/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +131 -46
- package/cpp/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-sycl/quants.hpp +8 -9
- package/cpp/llama.cpp/ggml/src/ggml-sycl/set_rows.cpp +43 -43
- package/cpp/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +287 -22
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +265 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +1 -5
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +2 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +3 -8
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +8 -2
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
- package/cpp/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +71 -16
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/ggml-webgpu.cpp +907 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +35 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
- package/cpp/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +56 -0
- package/cpp/llama.cpp/ggml/src/ggml.c +4 -6
- package/cpp/llama.cpp/gguf-py/gguf/constants.py +98 -0
- package/cpp/llama.cpp/gguf-py/gguf/metadata.py +4 -0
- package/cpp/llama.cpp/gguf-py/gguf/scripts/gguf_dump.py +24 -1
- package/cpp/llama.cpp/gguf-py/gguf/tensor_mapping.py +75 -52
- package/cpp/llama.cpp/include/llama.h +15 -7
- package/cpp/llama.cpp/models/templates/llama-cpp-rwkv-world.jinja +34 -0
- package/cpp/llama.cpp/models/templates/moonshotai-Kimi-K2.jinja +43 -0
- package/cpp/llama.cpp/requirements/requirements-all.txt +1 -0
- package/cpp/llama.cpp/requirements/requirements-server-bench.txt +5 -0
- package/cpp/llama.cpp/src/llama-arch.cpp +106 -0
- package/cpp/llama.cpp/src/llama-arch.h +5 -0
- package/cpp/llama.cpp/src/llama-batch.cpp +76 -70
- package/cpp/llama.cpp/src/llama-batch.h +24 -18
- package/cpp/llama.cpp/src/llama-chat.cpp +43 -1
- package/cpp/llama.cpp/src/llama-chat.h +2 -0
- package/cpp/llama.cpp/src/llama-context.cpp +180 -106
- package/cpp/llama.cpp/src/llama-context.h +26 -16
- package/cpp/llama.cpp/src/llama-cparams.h +3 -2
- package/cpp/llama.cpp/src/llama-graph.cpp +203 -39
- package/cpp/llama.cpp/src/llama-graph.h +147 -72
- package/cpp/llama.cpp/src/llama-hparams.cpp +40 -0
- package/cpp/llama.cpp/src/llama-hparams.h +10 -2
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +11 -5
- package/cpp/llama.cpp/src/llama-kv-cache-unified-iswa.h +3 -0
- package/cpp/llama.cpp/src/llama-kv-cache-unified.cpp +698 -302
- package/cpp/llama.cpp/src/llama-kv-cache-unified.h +89 -31
- package/cpp/llama.cpp/src/llama-memory-hybrid.cpp +1 -0
- package/cpp/llama.cpp/src/llama-memory-recurrent.cpp +16 -1
- package/cpp/llama.cpp/src/llama-model.cpp +1293 -312
- package/cpp/llama.cpp/src/llama-model.h +3 -4
- package/cpp/llama.cpp/src/llama-quant.cpp +1 -2
- package/cpp/llama.cpp/src/llama-vocab.cpp +363 -8
- package/cpp/llama.cpp/src/llama-vocab.h +2 -0
- package/cpp/llama.cpp/src/unicode.cpp +207 -0
- package/cpp/llama.cpp/src/unicode.h +2 -0
- package/ios/include/common.h +18 -4
- package/ios/include/llama.h +15 -7
- package/ios/libs/llama.xcframework/Info.plist +15 -15
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/ios-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4016 -3891
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/A/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/Versions/Current/llama +0 -0
- package/ios/libs/llama.xcframework/macos-arm64_x86_64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5267 -5059
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5238 -5030
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4014 -3889
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/tvos-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5303 -5095
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64/llama.framework/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/DWARF/llama +0 -0
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/aarch64/llama.yml +5274 -5066
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/dSYMs/llama.dSYM/Contents/Resources/Relocations/x86_64/llama.yml +4044 -3919
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/Headers/llama.h +15 -7
- package/ios/libs/llama.xcframework/xros-arm64_x86_64-simulator/llama.framework/llama +0 -0
- package/package.json +4 -4
|
@@ -71,12 +71,15 @@ struct rhs_packing_info {
|
|
|
71
71
|
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
|
72
72
|
std::function<size_t(size_t n, size_t k)>
|
|
73
73
|
> packed_size;
|
|
74
|
+
size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
|
|
74
75
|
std::variant<
|
|
75
76
|
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
|
76
77
|
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
|
|
77
78
|
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
|
|
78
79
|
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
|
|
79
80
|
> pack_func;
|
|
81
|
+
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
|
|
82
|
+
size_t kr, size_t bl, size_t num_bytes_multiplier);
|
|
80
83
|
};
|
|
81
84
|
|
|
82
85
|
struct ggml_kleidiai_kernels {
|
|
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
|
|
|
40
40
|
ggml_kleidiai_kernels * kernels;
|
|
41
41
|
} static ctx = { CPU_FEATURE_NONE, NULL };
|
|
42
42
|
|
|
43
|
+
static const char* cpu_feature_to_string(cpu_feature f) {
|
|
44
|
+
switch (f) {
|
|
45
|
+
case CPU_FEATURE_NONE: return "NONE";
|
|
46
|
+
case CPU_FEATURE_DOTPROD: return "DOTPROD";
|
|
47
|
+
case CPU_FEATURE_I8MM: return "I8MM";
|
|
48
|
+
case CPU_FEATURE_SVE: return "SVE";
|
|
49
|
+
case CPU_FEATURE_SME: return "SME";
|
|
50
|
+
default: return "UNKNOWN";
|
|
51
|
+
}
|
|
52
|
+
}
|
|
53
|
+
|
|
43
54
|
static void init_kleidiai_context(void) {
|
|
44
55
|
|
|
45
56
|
ggml_critical_section_start();
|
|
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
|
|
|
62
73
|
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
|
63
74
|
}
|
|
64
75
|
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
|
76
|
+
#ifndef NDEBUG
|
|
77
|
+
if (ctx.kernels) {
|
|
78
|
+
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
|
|
79
|
+
}
|
|
80
|
+
#endif
|
|
65
81
|
}
|
|
66
82
|
ggml_critical_section_end();
|
|
67
83
|
}
|
|
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
|
|
|
102
118
|
|
|
103
119
|
class tensor_traits : public ggml::cpu::tensor_traits {
|
|
104
120
|
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
|
121
|
+
if (op->op != GGML_OP_MUL_MAT) {
|
|
122
|
+
return false;
|
|
123
|
+
}
|
|
105
124
|
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
|
106
125
|
GGML_ASSERT(kernels);
|
|
107
126
|
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
|
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
135
154
|
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
136
155
|
return compute_forward_kv_cache(params, dst);
|
|
137
156
|
}
|
|
157
|
+
} else if (dst->op == GGML_OP_GET_ROWS) {
|
|
158
|
+
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
|
159
|
+
return compute_forward_get_rows(params, dst);
|
|
160
|
+
}
|
|
138
161
|
}
|
|
139
162
|
return false;
|
|
140
163
|
}
|
|
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
270
293
|
}
|
|
271
294
|
|
|
272
295
|
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
296
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
|
297
|
+
|
|
273
298
|
const ggml_tensor * src0 = dst->src[0];
|
|
274
299
|
const ggml_tensor * src1 = dst->src[1];
|
|
275
300
|
|
|
@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|
|
342
367
|
return true;
|
|
343
368
|
}
|
|
344
369
|
|
|
370
|
+
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
371
|
+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
|
372
|
+
GGML_ASSERT(ctx.kernels);
|
|
373
|
+
|
|
374
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
375
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
376
|
+
|
|
377
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
378
|
+
|
|
379
|
+
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
|
|
380
|
+
kernel_info * kernel = &ctx.kernels->gemm;
|
|
381
|
+
|
|
382
|
+
const int64_t nc = ne00;
|
|
383
|
+
const int64_t nr = ggml_nelements(src1);
|
|
384
|
+
|
|
385
|
+
const size_t block_rows = kernel->get_nr();
|
|
386
|
+
const size_t kr = kernel->get_kr();
|
|
387
|
+
|
|
388
|
+
const size_t num_bytes_multiplier = sizeof(uint16_t);
|
|
389
|
+
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
|
|
390
|
+
|
|
391
|
+
const int ith = params->ith;
|
|
392
|
+
const int nth = params->nth;
|
|
393
|
+
|
|
394
|
+
const int dr = (nr + nth - 1) / nth;
|
|
395
|
+
const int ir0 = dr * ith;
|
|
396
|
+
const int ir1 = MIN(ir0 + dr, nr);
|
|
397
|
+
|
|
398
|
+
for (int64_t i = ir0; i < ir1; ++i) {
|
|
399
|
+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
|
400
|
+
int64_t row_idx = ((const int32_t *)src1->data)[i];
|
|
401
|
+
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
|
|
402
|
+
|
|
403
|
+
float *out = (float *)((char *)dst->data + i * nb1);
|
|
404
|
+
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
return true;
|
|
408
|
+
}
|
|
409
|
+
|
|
345
410
|
public:
|
|
346
411
|
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
|
412
|
+
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
|
347
413
|
GGML_ASSERT(ctx.kernels);
|
|
348
414
|
const size_t n = tensor->ne[1];
|
|
349
415
|
const size_t k = tensor->ne[0];
|
|
@@ -351,17 +417,12 @@ public:
|
|
|
351
417
|
size_t kr = ctx.kernels->gemm.get_kr();
|
|
352
418
|
size_t sr = ctx.kernels->gemm.get_sr();
|
|
353
419
|
|
|
354
|
-
#ifndef NDEBUG
|
|
355
|
-
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
|
356
|
-
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
|
357
|
-
#endif
|
|
358
420
|
struct kai_rhs_pack_qs4cxs1s0_param params;
|
|
359
421
|
params.lhs_zero_point = 1;
|
|
360
422
|
params.rhs_zero_point = 8;
|
|
361
423
|
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
|
362
424
|
|
|
363
425
|
return 0;
|
|
364
|
-
|
|
365
426
|
GGML_UNUSED(data_size);
|
|
366
427
|
}
|
|
367
428
|
};
|
|
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
|
|
|
375
436
|
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
|
376
437
|
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
|
377
438
|
|
|
378
|
-
GGML_UNUSED(buffer);
|
|
379
439
|
return GGML_STATUS_SUCCESS;
|
|
440
|
+
GGML_UNUSED(buffer);
|
|
380
441
|
}
|
|
381
442
|
|
|
382
443
|
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
|
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
|
|
|
418
479
|
GGML_UNUSED(buft);
|
|
419
480
|
}
|
|
420
481
|
|
|
482
|
+
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
|
|
483
|
+
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
|
484
|
+
GGML_ASSERT(ctx.kernels);
|
|
485
|
+
|
|
486
|
+
const size_t n = tensor->ne[1];
|
|
487
|
+
const size_t k = tensor->ne[0];
|
|
488
|
+
const size_t nr = ctx.kernels->gemm.get_nr();
|
|
489
|
+
const size_t kr = ctx.kernels->gemm.get_kr();
|
|
490
|
+
|
|
491
|
+
return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
|
492
|
+
|
|
493
|
+
GGML_UNUSED(buft);
|
|
494
|
+
}
|
|
495
|
+
|
|
421
496
|
namespace ggml::cpu::kleidiai {
|
|
422
497
|
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
423
498
|
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
|
424
|
-
if (op->op == GGML_OP_MUL_MAT &&
|
|
499
|
+
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
|
|
425
500
|
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
|
426
501
|
op->src[0]->buffer &&
|
|
427
502
|
(ggml_n_dims(op->src[0]) == 2) &&
|
|
428
503
|
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
|
504
|
+
if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
|
|
505
|
+
return false;
|
|
506
|
+
}
|
|
429
507
|
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
430
508
|
return false;
|
|
431
509
|
}
|
|
432
|
-
if (op->src[1]->type == GGML_TYPE_F32 &&
|
|
510
|
+
if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
|
|
433
511
|
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
|
434
512
|
return true;
|
|
435
513
|
}
|
|
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
438
516
|
}
|
|
439
517
|
|
|
440
518
|
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
|
441
|
-
if (op->op == GGML_OP_MUL_MAT) {
|
|
519
|
+
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
|
|
442
520
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
443
521
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
444
522
|
}
|
|
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
|
|
|
469
547
|
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
|
|
470
548
|
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
|
|
471
549
|
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
|
|
472
|
-
/* .get_alloc_size = */
|
|
550
|
+
/* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
|
|
473
551
|
/* .is_host = */ nullptr,
|
|
474
552
|
},
|
|
475
553
|
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
|