@fugood/llama.node 1.3.1 → 1.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +4 -3
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +6 -6
- package/src/llama.cpp/CMakeLists.txt +4 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -37
- package/src/llama.cpp/common/arg.cpp +7 -0
- package/src/llama.cpp/common/common.cpp +1 -5
- package/src/llama.cpp/common/common.h +2 -1
- package/src/llama.cpp/common/download.cpp +47 -29
- package/src/llama.cpp/common/log.cpp +6 -0
- package/src/llama.cpp/common/log.h +2 -0
- package/src/llama.cpp/ggml/include/ggml.h +71 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +34 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +50 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +283 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +235 -34
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +289 -317
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +95 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +10 -0
- package/src/llama.cpp/src/CMakeLists.txt +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +32 -0
- package/src/llama.cpp/src/llama-arch.h +2 -0
- package/src/llama.cpp/src/llama-graph.cpp +2 -1
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +4 -3
- package/src/llama.cpp/src/llama-model.cpp +102 -0
- package/src/llama.cpp/src/llama-model.h +2 -0
- package/src/llama.cpp/src/llama-sampling.cpp +10 -5
- package/src/llama.cpp/src/llama-vocab.cpp +16 -1
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/models/afmoe.cpp +187 -0
- package/src/llama.cpp/src/models/ernie4-5.cpp +4 -5
- package/src/llama.cpp/src/models/models.h +4 -0
- package/src/llama.cpp/src/models/openai-moe-iswa.cpp +2 -1
- package/src/llama.cpp/src/unicode.cpp +77 -0
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
// KleidiAI micro-kernels
|
|
6
6
|
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
|
7
|
+
#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
|
|
7
8
|
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
|
8
9
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
|
9
10
|
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
|
@@ -11,20 +12,31 @@
|
|
|
11
12
|
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
|
12
13
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
|
13
14
|
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
|
15
|
+
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
|
|
16
|
+
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
|
|
17
|
+
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
|
|
18
|
+
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
|
|
19
|
+
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
|
|
20
|
+
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
|
|
14
21
|
|
|
15
22
|
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
|
16
23
|
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
|
17
24
|
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
|
|
18
25
|
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
|
26
|
+
#include "kai_lhs_quant_pack_qai8dxp_f32.h"
|
|
19
27
|
|
|
20
28
|
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
|
21
29
|
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
|
22
30
|
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
|
31
|
+
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
|
|
23
32
|
|
|
24
33
|
#include "kai_common.h"
|
|
25
34
|
|
|
26
35
|
#include "simd-mappings.h"
|
|
27
36
|
|
|
37
|
+
#define GGML_COMMON_DECL_CPP
|
|
38
|
+
#include "ggml-common.h"
|
|
39
|
+
|
|
28
40
|
#include "kernels.h"
|
|
29
41
|
|
|
30
42
|
// Number of elements in a true array (not a pointer). The whole expansion is
// parenthesized so it stays a single operand inside larger expressions
// (e.g. `total / NELEMS(a)`), and the argument is parenthesized before it is
// dereferenced so expression arguments expand correctly.
#define NELEMS(x) (sizeof(x) / sizeof(*(x)))
|
|
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|
|
55
67
|
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
56
68
|
}
|
|
57
69
|
|
|
70
|
+
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
|
71
|
+
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|
72
|
+
const void* lhs, const void* rhs, void* dst,
|
|
73
|
+
size_t dst_stride_row, size_t dst_stride_col,
|
|
74
|
+
float clamp_min, float clamp_max) {
|
|
75
|
+
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
76
|
+
}
|
|
77
|
+
|
|
58
78
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
|
59
79
|
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
|
60
80
|
return Fn(m, k, bl, mr, kr, sr);
|
|
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
|
|
|
93
113
|
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
|
94
114
|
}
|
|
95
115
|
|
|
116
|
+
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
|
|
117
|
+
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
|
|
118
|
+
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
|
|
119
|
+
Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
|
|
120
|
+
}
|
|
121
|
+
|
|
96
122
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
|
97
123
|
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
|
98
124
|
return Fn(n, k, nr, kr, bl);
|
|
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
|
|
|
124
150
|
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
|
125
151
|
}
|
|
126
152
|
|
|
153
|
+
// Adapter binding a KleidiAI `qsi8cx` RHS packing routine to the type-erased
// pack_func_ex slot. That routine takes int8 weights, float bias and float scale
// pointers plus its own params struct. The block length and source row stride
// are not forwarded because the wrapped routine's signature has neither.
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
                                       void* rhs_packed, size_t extra_bytes, const void* params) {
    const auto * weights_i8 = static_cast<const int8_t *>(rhs);
    const auto * bias_f32   = static_cast<const float *>(bias);
    const auto * scale_f32  = static_cast<const float *>(scale);
    const auto * pack_cfg   = static_cast<const kai_rhs_pack_qsi8cx_params *>(params);

    Fn(num_groups, n, k, nr, kr, sr,
       weights_i8, bias_f32, scale_f32,
       rhs_packed, extra_bytes,
       pack_cfg);
}
|
|
164
|
+
|
|
127
165
|
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
|
128
166
|
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
|
129
167
|
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
|
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
|
|
|
213
251
|
GGML_UNUSED(kr);
|
|
214
252
|
}
|
|
215
253
|
|
|
254
|
+
static void dequantize_row_qsi8cxp(
|
|
255
|
+
const void *packed_data,
|
|
256
|
+
int32_t row_idx,
|
|
257
|
+
int64_t k,
|
|
258
|
+
float *out,
|
|
259
|
+
size_t nr,
|
|
260
|
+
size_t packed_row_stride,
|
|
261
|
+
size_t kr,
|
|
262
|
+
size_t bl,
|
|
263
|
+
size_t num_bytes_multiplier
|
|
264
|
+
) {
|
|
265
|
+
GGML_UNUSED(bl);
|
|
266
|
+
GGML_UNUSED(num_bytes_multiplier);
|
|
267
|
+
|
|
268
|
+
const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
|
|
269
|
+
const size_t group_idx = row_idx / nr;
|
|
270
|
+
const size_t row_in_group = row_idx % nr;
|
|
271
|
+
|
|
272
|
+
const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
|
|
273
|
+
const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
|
|
274
|
+
|
|
275
|
+
const size_t num_blocks = k_internal / kr;
|
|
276
|
+
|
|
277
|
+
for (size_t block = 0; block < num_blocks; ++block) {
|
|
278
|
+
const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
|
|
279
|
+
for (size_t i = 0; i < kr; ++i) {
|
|
280
|
+
const size_t k_idx = block * kr + i;
|
|
281
|
+
if (k_idx < (size_t) k) {
|
|
282
|
+
out[k_idx] = static_cast<float>(block_ptr[i]);
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
const uint8_t * sums_ptr = group_ptr + nr * k_internal;
|
|
288
|
+
GGML_UNUSED(sums_ptr);
|
|
289
|
+
|
|
290
|
+
const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
|
|
291
|
+
const float scale = scale_ptr[row_in_group];
|
|
292
|
+
|
|
293
|
+
if (scale == 0.0f) {
|
|
294
|
+
for (size_t i = 0; i < (size_t) k; ++i) {
|
|
295
|
+
out[i] = 0.0f;
|
|
296
|
+
}
|
|
297
|
+
return;
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
for (size_t i = 0; i < (size_t) k; ++i) {
|
|
301
|
+
out[i] *= scale;
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
|
|
216
305
|
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
217
306
|
#if defined(__ARM_FEATURE_SME)
|
|
218
307
|
{
|
|
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
548
637
|
#endif
|
|
549
638
|
};
|
|
550
639
|
|
|
640
|
+
// KleidiAI kernel sets for matmuls with Q8_0 weights (rhs_type, compared against
// src[0]), F32 activations (lhs_type, compared against src[1]) and an F32 result
// (op_type).
//
// Every entry uses the same packing routines:
//   - LHS:   kai_run_lhs_quant_pack_qai8dxp_f32
//   - RHS:   kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon
//   - back to float: dequantize_row_qsi8cxp
//
// Entries are listed in preference order (SME, then I8MM, then DOTPROD).
// ggml_kleidiai_select_kernels_q8_0 and the Q8_0 fallback in
// ggml_kleidiai_select_kernels both take the first entry whose required_cpu bits
// are all available. Each entry is only compiled in when its __ARM_FEATURE_* macro
// is defined.
//
// Within one entry, the GEMM and GEMV kernels share a single rhs_info and
// therefore a single packed-RHS layout. Their RHS tile shapes, as named in the
// kernel identifiers (qsi8cxp4vlx4 / qsi8cxp4x8 / qsi8cxp4x4), agree pairwise.
// Keep them matched if a kernel is swapped.
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
#if defined(__ARM_FEATURE_SME)
    {
        /* SME GEMM */
        {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
        },
        /* .gemm_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* SME GEMV */
        {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
        },
        /* .gemv_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* .rhs_info = */ {
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
            /* .to_float = */ dequantize_row_qsi8cxp,
            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
        },
        /* .required_cpu = */ CPU_FEATURE_SME,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q8_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
    {
        /* I8MM GEMM */
        {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
        },
        /* .gemm_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* I8MM GEMV (dotprod fallback) */
        // The GEMV path uses a dotprod kernel, which is why this entry's
        // required_cpu includes CPU_FEATURE_DOTPROD in addition to CPU_FEATURE_I8MM.
        {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
        },
        /* .gemv_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* .rhs_info = */ {
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
            /* .to_float = */ dequantize_row_qsi8cxp,
            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q8_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
#if defined(__ARM_FEATURE_DOTPROD)
    {
        /* DOTPROD GEMM */
        {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
        },
        /* .gemm_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* DOTPROD GEMV */
        {
            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
        },
        /* .gemv_lhs_info = */ {
            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
        },
        /* .rhs_info = */ {
            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
            /* .to_float = */ dequantize_row_qsi8cxp,
            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
        },
        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
        /* .lhs_type = */ GGML_TYPE_F32,
        /* .rhs_type = */ GGML_TYPE_Q8_0,
        /* .op_type = */ GGML_TYPE_F32,
    },
#endif
};
|
|
807
|
+
|
|
551
808
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
|
552
809
|
ggml_kleidiai_kernels * kernel = nullptr;
|
|
553
810
|
|
|
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|
|
562
819
|
break;
|
|
563
820
|
}
|
|
564
821
|
}
|
|
822
|
+
if (!kernel) {
|
|
823
|
+
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
|
824
|
+
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
|
825
|
+
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
|
826
|
+
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
|
827
|
+
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
|
|
828
|
+
kernel = &gemm_gemv_kernels_q8[i];
|
|
829
|
+
break;
|
|
830
|
+
}
|
|
831
|
+
}
|
|
832
|
+
}
|
|
565
833
|
#endif
|
|
566
834
|
}
|
|
567
835
|
|
|
@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
|
|
|
582
850
|
|
|
583
851
|
return kernels;
|
|
584
852
|
}
|
|
853
|
+
|
|
854
|
+
// Pick the Q8_0 KleidiAI kernel set for the running CPU.
//
// Walks gemm_gemv_kernels_q8 in table order and returns the first entry whose
// required_cpu bits are all present in `features`. Returns nullptr when no entry
// qualifies, or when the build enabled none of the SME / DOTPROD / MATMUL_INT8
// feature macros that guard the table's entries.
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
    for (ggml_kleidiai_kernels & candidate : gemm_gemv_kernels_q8) {
        const auto needed = candidate.required_cpu;
        if ((features & needed) == needed) {
            return &candidate;
        }
    }
#else
    (void) features;
#endif
    return nullptr;
}
|
|
@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
|
|
|
87
87
|
|
|
88
88
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
|
89
89
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|
|
90
|
+
// Returns the first Q8_0 KleidiAI kernel set whose required CPU features are all
// present in `features`, or nullptr if none matches or none were compiled in.
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);
|