@fugood/llama.node 1.3.1 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. package/CMakeLists.txt +4 -3
  2. package/package.json +14 -14
  3. package/scripts/llama.cpp.patch +6 -6
  4. package/src/llama.cpp/CMakeLists.txt +4 -0
  5. package/src/llama.cpp/common/CMakeLists.txt +6 -37
  6. package/src/llama.cpp/common/arg.cpp +7 -0
  7. package/src/llama.cpp/common/common.cpp +1 -5
  8. package/src/llama.cpp/common/common.h +2 -1
  9. package/src/llama.cpp/common/download.cpp +47 -29
  10. package/src/llama.cpp/common/log.cpp +6 -0
  11. package/src/llama.cpp/common/log.h +2 -0
  12. package/src/llama.cpp/ggml/include/ggml.h +71 -0
  13. package/src/llama.cpp/ggml/src/CMakeLists.txt +16 -0
  14. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +34 -11
  15. package/src/llama.cpp/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  16. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +50 -16
  17. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +283 -0
  18. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +1 -0
  19. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +235 -34
  20. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +289 -317
  21. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -4
  22. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +95 -42
  23. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +16 -0
  24. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +2 -0
  25. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +17 -0
  26. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +10 -0
  27. package/src/llama.cpp/src/CMakeLists.txt +6 -0
  28. package/src/llama.cpp/src/llama-arch.cpp +32 -0
  29. package/src/llama.cpp/src/llama-arch.h +2 -0
  30. package/src/llama.cpp/src/llama-graph.cpp +2 -1
  31. package/src/llama.cpp/src/llama-memory-recurrent.cpp +4 -3
  32. package/src/llama.cpp/src/llama-model.cpp +102 -0
  33. package/src/llama.cpp/src/llama-model.h +2 -0
  34. package/src/llama.cpp/src/llama-sampling.cpp +10 -5
  35. package/src/llama.cpp/src/llama-vocab.cpp +16 -1
  36. package/src/llama.cpp/src/llama-vocab.h +1 -0
  37. package/src/llama.cpp/src/models/afmoe.cpp +187 -0
  38. package/src/llama.cpp/src/models/ernie4-5.cpp +4 -5
  39. package/src/llama.cpp/src/models/models.h +4 -0
  40. package/src/llama.cpp/src/models/openai-moe-iswa.cpp +2 -1
  41. package/src/llama.cpp/src/unicode.cpp +77 -0
@@ -4,6 +4,7 @@
4
4
 
5
5
  // KleidiAI micro-kernels
6
6
  #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
7
+ #include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
7
8
  #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
8
9
  #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
9
10
  #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
@@ -11,20 +12,31 @@
11
12
  #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
12
13
  #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
13
14
  #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
15
+ #include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
16
+ #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
17
+ #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
18
+ #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
19
+ #include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
20
+ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
14
21
 
15
22
  #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
16
23
  #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
17
24
  #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
18
25
  #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
26
+ #include "kai_lhs_quant_pack_qai8dxp_f32.h"
19
27
 
20
28
  #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
21
29
  #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
22
30
  #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
31
+ #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
23
32
 
24
33
  #include "kai_common.h"
25
34
 
26
35
  #include "simd-mappings.h"
27
36
 
37
+ #define GGML_COMMON_DECL_CPP
38
+ #include "ggml-common.h"
39
+
28
40
  #include "kernels.h"
29
41
 
30
42
  #define NELEMS(x) sizeof(x) / sizeof(*x)
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
55
67
  Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
56
68
  }
57
69
 
70
// Adapter used as `.run_kernel_ex`: lets a KleidiAI matmul micro-kernel whose
// destination is `float *` be called through the type-erased `void *dst`
// signature. The block-length argument is accepted for signature uniformity
// and ignored.
template<void(*MatmulFn)(size_t, size_t, size_t, const void *, const void *, float *, size_t, size_t, float, float)>
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
                                         const void * lhs, const void * rhs, void * dst,
                                         size_t dst_stride_row, size_t dst_stride_col,
                                         float clamp_min, float clamp_max) {
    float * const dst_f32 = static_cast<float *>(dst);
    MatmulFn(m, n, k, lhs, rhs, dst_f32, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
77
+
58
78
  template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
59
79
  static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
60
80
  return Fn(m, k, bl, mr, kr, sr);
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
93
113
  Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
94
114
  }
95
115
 
116
// Adapter used as `.pack_func_ex` for LHS packers that read a `const float *`
// source and take no block-length parameter: the `bl` argument of the uniform
// signature is dropped and the type-erased source pointer is cast back to float.
template<void(*PackFn)(size_t, size_t, size_t, size_t, size_t, size_t, const float *, size_t, void *)>
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
                                            size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
    const float * const src_f32 = static_cast<const float *>(lhs);
    PackFn(m, k, mr, kr, sr, m_idx_start, src_f32, lhs_stride, lhs_packed);
}
121
+
96
122
  template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
97
123
  static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
98
124
  return Fn(n, k, nr, kr, bl);
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
124
150
  static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
125
151
  }
126
152
 
153
// Adapter used as `.pack_func_ex` for the int8 per-channel RHS packer:
// the uniform 14-argument signature carries `bl` and `rhs_stride`, which the
// wrapped 12-argument packer does not take, so both are dropped. The
// type-erased pointers are cast back to int8 weights, f32 bias, f32 scales
// and the packer's parameter struct before forwarding.
template<void(*PackFn)(size_t, size_t, size_t, size_t, size_t, size_t,
                       const int8_t *, const float *, const float *, void *, size_t,
                       const struct kai_rhs_pack_qsi8cx_params *)>
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
                                       void* rhs_packed, size_t extra_bytes, const void* params) {
    const int8_t * const weights_q8 = static_cast<const int8_t *>(rhs);
    const float  * const bias_f32   = static_cast<const float *>(bias);
    const float  * const scale_f32  = static_cast<const float *>(scale);
    const kai_rhs_pack_qsi8cx_params * const pack_params =
        static_cast<const kai_rhs_pack_qsi8cx_params *>(params);

    PackFn(num_groups, n, k, nr, kr, sr,
           weights_q8, bias_f32, scale_f32,
           rhs_packed, extra_bytes, pack_params);
}
164
+
127
165
  template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
128
166
  static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
129
167
  size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
213
251
  GGML_UNUSED(kr);
214
252
  }
215
253
 
254
+ static void dequantize_row_qsi8cxp(
255
+ const void *packed_data,
256
+ int32_t row_idx,
257
+ int64_t k,
258
+ float *out,
259
+ size_t nr,
260
+ size_t packed_row_stride,
261
+ size_t kr,
262
+ size_t bl,
263
+ size_t num_bytes_multiplier
264
+ ) {
265
+ GGML_UNUSED(bl);
266
+ GGML_UNUSED(num_bytes_multiplier);
267
+
268
+ const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
269
+ const size_t group_idx = row_idx / nr;
270
+ const size_t row_in_group = row_idx % nr;
271
+
272
+ const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
273
+ const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
274
+
275
+ const size_t num_blocks = k_internal / kr;
276
+
277
+ for (size_t block = 0; block < num_blocks; ++block) {
278
+ const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
279
+ for (size_t i = 0; i < kr; ++i) {
280
+ const size_t k_idx = block * kr + i;
281
+ if (k_idx < (size_t) k) {
282
+ out[k_idx] = static_cast<float>(block_ptr[i]);
283
+ }
284
+ }
285
+ }
286
+
287
+ const uint8_t * sums_ptr = group_ptr + nr * k_internal;
288
+ GGML_UNUSED(sums_ptr);
289
+
290
+ const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
291
+ const float scale = scale_ptr[row_in_group];
292
+
293
+ if (scale == 0.0f) {
294
+ for (size_t i = 0; i < (size_t) k; ++i) {
295
+ out[i] = 0.0f;
296
+ }
297
+ return;
298
+ }
299
+
300
+ for (size_t i = 0; i < (size_t) k; ++i) {
301
+ out[i] *= scale;
302
+ }
303
+ }
304
+
216
305
  static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
217
306
  #if defined(__ARM_FEATURE_SME)
218
307
  {
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
548
637
  #endif
549
638
  };
550
639
 
640
+ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
641
+ #if defined(__ARM_FEATURE_SME)
642
+ {
643
+ /* SME GEMM */
644
+ {
645
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
646
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
647
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
648
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
649
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
650
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
651
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
652
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
653
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
654
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
655
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
656
+ },
657
+ /* .gemm_lhs_info = */ {
658
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
659
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
660
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
661
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
662
+ },
663
+ /* SME GEMV */
664
+ {
665
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
666
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
667
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
668
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
669
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
670
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
671
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
672
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
673
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
674
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
675
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
676
+ },
677
+ /* .gemv_lhs_info = */ {
678
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
679
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
680
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
681
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
682
+ },
683
+ /* .rhs_info = */ {
684
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
685
+ /* .to_float = */ dequantize_row_qsi8cxp,
686
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
687
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
688
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
689
+ },
690
+ /* .required_cpu = */ CPU_FEATURE_SME,
691
+ /* .lhs_type = */ GGML_TYPE_F32,
692
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
693
+ /* .op_type = */ GGML_TYPE_F32,
694
+ },
695
+ #endif
696
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
697
+ {
698
+ /* I8MM GEMM */
699
+ {
700
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
701
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
702
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
703
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
704
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
705
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
706
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
707
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
708
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
709
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
710
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
711
+ },
712
+ /* .gemm_lhs_info = */ {
713
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
714
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
715
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
716
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
717
+ },
718
+ /* I8MM GEMV (dotprod fallback) */
719
+ {
720
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
721
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
722
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
723
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
724
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
725
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
726
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
727
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
728
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
729
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
730
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
731
+ },
732
+ /* .gemv_lhs_info = */ {
733
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
734
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
735
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
736
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
737
+ },
738
+ /* .rhs_info = */ {
739
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
740
+ /* .to_float = */ dequantize_row_qsi8cxp,
741
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
742
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
743
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
744
+ },
745
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
746
+ /* .lhs_type = */ GGML_TYPE_F32,
747
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
748
+ /* .op_type = */ GGML_TYPE_F32,
749
+ },
750
+ #endif
751
+ #if defined(__ARM_FEATURE_DOTPROD)
752
+ {
753
+ /* DOTPROD GEMM */
754
+ {
755
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
756
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
757
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
758
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
759
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
760
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
761
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
762
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
763
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
764
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
765
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
766
+ },
767
+ /* .gemm_lhs_info = */ {
768
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
769
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
770
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
771
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
772
+ },
773
+ /* DOTPROD GEMV */
774
+ {
775
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
776
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
777
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
778
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
779
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
780
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
781
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
782
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
783
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
784
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
785
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
786
+ },
787
+ /* .gemv_lhs_info = */ {
788
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
789
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
790
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
791
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
792
+ },
793
+ /* .rhs_info = */ {
794
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
795
+ /* .to_float = */ dequantize_row_qsi8cxp,
796
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
797
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
798
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
799
+ },
800
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD,
801
+ /* .lhs_type = */ GGML_TYPE_F32,
802
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
803
+ /* .op_type = */ GGML_TYPE_F32,
804
+ },
805
+ #endif
806
+ };
807
+
551
808
  ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
552
809
  ggml_kleidiai_kernels * kernel = nullptr;
553
810
 
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
562
819
  break;
563
820
  }
564
821
  }
822
+ if (!kernel) {
823
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
824
+ if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
825
+ gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
826
+ gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
827
+ gemm_gemv_kernels_q8[i].op_type == tensor->type) {
828
+ kernel = &gemm_gemv_kernels_q8[i];
829
+ break;
830
+ }
831
+ }
832
+ }
565
833
  #endif
566
834
  }
567
835
 
@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
582
850
 
583
851
  return kernels;
584
852
  }
853
+
854
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
855
+ ggml_kleidiai_kernels * kernels = nullptr;
856
+
857
+ #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
858
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
859
+ if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
860
+ kernels = &gemm_gemv_kernels_q8[i];
861
+ break;
862
+ }
863
+ }
864
+ #endif
865
+
866
+ return kernels;
867
+ }
@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
87
87
 
88
88
  ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
89
89
  ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
90
+ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);