@fugood/llama.node 1.2.2 → 1.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +33 -11
- package/src/llama.cpp/CMakeLists.txt +1 -0
- package/src/llama.cpp/common/CMakeLists.txt +46 -2
- package/src/llama.cpp/common/arg.cpp +423 -186
- package/src/llama.cpp/common/arg.h +0 -1
- package/src/llama.cpp/common/chat-parser.cpp +154 -13
- package/src/llama.cpp/common/chat-parser.h +3 -0
- package/src/llama.cpp/common/chat.cpp +217 -6
- package/src/llama.cpp/common/chat.h +5 -3
- package/src/llama.cpp/common/common.cpp +23 -6
- package/src/llama.cpp/common/common.h +6 -4
- package/src/llama.cpp/common/http.h +73 -0
- package/src/llama.cpp/common/sampling.cpp +1 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +7 -6
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-rpc.h +8 -9
- package/src/llama.cpp/ggml/include/ggml-zdnn.h +3 -0
- package/src/llama.cpp/ggml/include/ggml.h +22 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +3 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +12 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/loongarch/quants.c +12 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +100 -3
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +18 -3
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +209 -96
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +32 -44
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +107 -83
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +27 -19
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +8 -8
- package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +103 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +66 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +6 -5
- package/src/llama.cpp/include/llama.h +23 -11
- package/src/llama.cpp/src/llama-arch.cpp +93 -0
- package/src/llama.cpp/src/llama-arch.h +22 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -1
- package/src/llama.cpp/src/llama-context.cpp +157 -0
- package/src/llama.cpp/src/llama-context.h +10 -0
- package/src/llama.cpp/src/llama-graph.cpp +57 -22
- package/src/llama.cpp/src/llama-graph.h +10 -1
- package/src/llama.cpp/src/llama-hparams.h +17 -2
- package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +10 -2
- package/src/llama.cpp/src/llama-kv-cache-iswa.h +2 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +10 -5
- package/src/llama.cpp/src/llama-kv-cache.h +2 -0
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +19 -9
- package/src/llama.cpp/src/llama-memory-hybrid.h +2 -0
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +19 -3
- package/src/llama.cpp/src/llama-memory-recurrent.h +3 -0
- package/src/llama.cpp/src/llama-memory.h +3 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +2 -0
- package/src/llama.cpp/src/llama-model.cpp +582 -45
- package/src/llama.cpp/src/llama-model.h +23 -1
- package/src/llama.cpp/src/llama-sampling.cpp +5 -0
- package/src/llama.cpp/src/llama-vocab.cpp +7 -1
- package/src/llama.cpp/src/llama-vocab.h +41 -40
- package/src/llama.cpp/src/unicode.h +43 -0
|
@@ -29,6 +29,108 @@
|
|
|
29
29
|
|
|
30
30
|
// Number of elements in a true array `x` (not a pointer). The expansion and the
// argument are fully parenthesized so the macro behaves as a single operand
// inside larger expressions (e.g. `total / NELEMS(arr)`).
#define NELEMS(x) (sizeof(x) / sizeof(*(x)))
|
|
31
31
|
|
|
32
|
+
// Adapter for a three-argument offset function: exposes Fn through a plain
// (size_t, size_t, size_t) function and forwards all three arguments unchanged.
template<size_t(*Fn)(size_t,size_t,size_t)>
|
|
33
|
+
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
|
34
|
+
return Fn(a, b, c);
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
// Adapter for a two-argument offset function: presents the same three-argument
// signature as kernel_offs_fn3, but the third (unnamed) argument is ignored and
// only a and b are passed to Fn.
template<size_t(*Fn)(size_t,size_t)>
|
|
38
|
+
static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) {
|
|
39
|
+
return Fn(a, b);
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
// Adapter for an 11-argument kernel Fn that takes `bl` and writes to a float*
// destination. The caller-facing signature uses void* for dst; it is cast to
// float* here and every other argument is forwarded as-is.
template<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
|
43
|
+
static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl,
|
|
44
|
+
const void* lhs, const void* rhs, void* dst,
|
|
45
|
+
size_t dst_stride_row, size_t dst_stride_col,
|
|
46
|
+
float clamp_min, float clamp_max) {
|
|
47
|
+
Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
// Adapter for a 10-argument kernel Fn that has no `bl` parameter and takes a
// void* destination. Shares kernel_run_fn11's signature; the `bl` argument is
// accepted but not passed on, and dst is forwarded without a cast.
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)>
|
|
51
|
+
static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|
52
|
+
const void* lhs, const void* rhs, void* dst,
|
|
53
|
+
size_t dst_stride_row, size_t dst_stride_col,
|
|
54
|
+
float clamp_min, float clamp_max) {
|
|
55
|
+
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
// Adapter for a six-argument LHS size function: forwards
// (m, k, bl, mr, kr, sr) to Fn unchanged and returns its result.
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
|
59
|
+
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
|
60
|
+
return Fn(m, k, bl, mr, kr, sr);
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
// Adapter for a five-argument LHS size function: same signature as lhs_ps_fn6,
// but `bl` is ignored and Fn receives (m, k, mr, kr, sr).
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
|
64
|
+
static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
|
|
65
|
+
return Fn(m, k, mr, kr, sr);
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
// Adapter for a six-argument LHS offset function: forwards
// (m_idx, k, bl, mr, kr, sr) to Fn unchanged and returns its result.
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
|
69
|
+
static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
|
70
|
+
return Fn(m_idx, k, bl, mr, kr, sr);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
// Adapter for a five-argument LHS offset function: same signature as
// lhs_offs_fn6, but `bl` is ignored and Fn receives (m_idx, k, mr, kr, sr).
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
|
74
|
+
static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
|
|
75
|
+
return Fn(m_idx, k, mr, kr, sr);
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
// Adapter for a 10-argument LHS packing function whose source pointer is typed
// const float*. The caller passes lhs as const void*; it is cast to
// const float* here and all other arguments (including bl) are forwarded as-is.
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
|
|
79
|
+
static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
|
|
80
|
+
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
|
|
81
|
+
Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
// Adapter for a 10-argument LHS packing function whose source pointer is
// already const void*: every argument (including bl) is forwarded unchanged.
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
|
|
85
|
+
static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
|
|
86
|
+
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
|
|
87
|
+
Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
// Adapter for a 9-argument LHS packing function with no `bl` parameter: shares
// the 10-argument signature of the other lhs_pack_* adapters, ignores `bl`, and
// forwards the remaining arguments in order.
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
|
|
91
|
+
static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
|
|
92
|
+
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
|
|
93
|
+
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
// Adapter for a five-argument RHS size function: forwards
// (n, k, nr, kr, bl) to Fn unchanged and returns its result.
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
|
97
|
+
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
|
98
|
+
return Fn(n, k, nr, kr, bl);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
// Adapter for a two-argument RHS size function: same signature as rhs_ps_fn5,
// but nr, kr and bl are ignored and Fn receives only (n, k).
template<size_t(*Fn)(size_t,size_t)>
|
|
102
|
+
static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
|
|
103
|
+
return Fn(n, k);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
// Adapter for a four-argument RHS stride function: forwards
// (k, nr, kr, bl) to Fn unchanged and returns its result.
template<size_t(*Fn)(size_t,size_t,size_t,size_t)>
|
|
107
|
+
static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) {
|
|
108
|
+
return Fn(k, nr, kr, bl);
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
// Adapter for a one-argument RHS stride function: same signature as
// rhs_stride_fn4, but nr, kr and bl are ignored and Fn receives only k.
template<size_t(*Fn)(size_t)>
|
|
112
|
+
static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
|
|
113
|
+
return Fn(k);
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
// Adapter for a 12-argument RHS packing function that takes `bl`, a
// const uint8_t* source, a const float* bias and a typed
// kai_rhs_pack_qs4cxs1s0_param* parameter block. The caller-facing signature
// uses untyped pointers, which are cast here; rhs_stride and scale are accepted
// but not passed to Fn.
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)>
|
|
117
|
+
static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
|
|
118
|
+
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/,
|
|
119
|
+
void* rhs_packed, size_t extra_bytes, const void* params) {
|
|
120
|
+
Fn(num_groups, n, k, nr, kr, sr, bl,
|
|
121
|
+
static_cast<const uint8_t*>(rhs),
|
|
122
|
+
static_cast<const float*>(bias),
|
|
123
|
+
rhs_packed, extra_bytes,
|
|
124
|
+
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
// Adapter for a 13-argument RHS packing function that takes rhs_stride and
// scale and uses untyped pointers throughout. Shares rhs_pack_fn12's signature;
// `bl` is ignored and the remaining arguments are forwarded without casts.
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
|
128
|
+
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
|
129
|
+
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
|
130
|
+
void* rhs_packed, size_t extra_bytes, const void* params) {
|
|
131
|
+
Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params);
|
|
132
|
+
}
|
|
133
|
+
|
|
32
134
|
static const size_t INT4_PER_BYTE = 2;
|
|
33
135
|
static const size_t INT4_BITS = 4;
|
|
34
136
|
static const int Q4_0_ZERO_POINT = 8;
|
|
@@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
122
224
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
123
225
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
124
226
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
125
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
126
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
127
227
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
128
228
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
129
|
-
/* .
|
|
229
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
|
230
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
|
231
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
|
130
232
|
},
|
|
233
|
+
|
|
131
234
|
/* .gemm_lhs_info = */ {
|
|
132
235
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
|
133
|
-
/* .
|
|
134
|
-
/* .
|
|
135
|
-
/* .
|
|
236
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
|
|
237
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
|
|
238
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
|
|
136
239
|
},
|
|
137
240
|
/* SME GEMV */
|
|
138
241
|
/* .kern_info = */ {
|
|
@@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
142
245
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
143
246
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
144
247
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
145
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
146
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
147
248
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
148
249
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
149
|
-
/* .
|
|
250
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
|
251
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
|
252
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
|
150
253
|
},
|
|
151
254
|
/* .gemv_lhs_info = */ {
|
|
152
255
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
|
153
|
-
/* .
|
|
154
|
-
/* .
|
|
155
|
-
/* .
|
|
256
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
|
|
257
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
|
|
258
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
|
|
156
259
|
},
|
|
157
260
|
/* .rhs_info = */ {
|
|
158
|
-
/* .
|
|
159
|
-
/* .
|
|
160
|
-
/* .
|
|
161
|
-
/* .
|
|
261
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
|
262
|
+
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
|
|
263
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
|
264
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
|
265
|
+
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
|
162
266
|
},
|
|
163
267
|
/* .required_cpu = */ CPU_FEATURE_SME,
|
|
164
268
|
/* .lhs_type = */ GGML_TYPE_F32,
|
|
@@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
174
278
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
175
279
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
176
280
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
177
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
178
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
179
281
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
180
282
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
181
|
-
/* .
|
|
283
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
|
284
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
|
285
|
+
/* .run_kernel_ex = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
|
182
286
|
},
|
|
183
287
|
/* .gemm_lhs_info = */ {
|
|
184
288
|
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
|
185
|
-
/* .
|
|
186
|
-
/* .
|
|
187
|
-
/* .
|
|
289
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
|
|
290
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
|
|
291
|
+
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
|
|
188
292
|
},
|
|
189
293
|
/* SME GEMV */
|
|
190
294
|
/* .kern_info = */ {
|
|
@@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
194
298
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
195
299
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
196
300
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
197
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
198
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
199
301
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
200
302
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
201
|
-
/* .
|
|
303
|
+
/* .get_lhs_offset_ex = */ nullptr,
|
|
304
|
+
/* .get_rhs_packed_offset_ex = */ nullptr,
|
|
305
|
+
/* .run_kernel_ex = */ nullptr,
|
|
202
306
|
},
|
|
203
307
|
/* .gemv_lhs_info = */ {
|
|
204
308
|
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
|
205
|
-
/* .
|
|
206
|
-
/* .
|
|
207
|
-
/* .
|
|
309
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
|
|
310
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
|
|
311
|
+
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
|
|
208
312
|
},
|
|
209
313
|
/* .rhs_info = */ {
|
|
210
|
-
/* .
|
|
211
|
-
/* .
|
|
212
|
-
/* .
|
|
213
|
-
/* .
|
|
314
|
+
/* .packed_stride = */ nullptr,
|
|
315
|
+
/* .to_float = */ nullptr,
|
|
316
|
+
/* .packed_size_ex = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
|
317
|
+
/* .packed_stride_ex = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
|
318
|
+
/* .pack_func_ex = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
|
214
319
|
},
|
|
215
320
|
/* .required_cpu = */ CPU_FEATURE_SME,
|
|
216
321
|
/* .lhs_type = */ GGML_TYPE_F32,
|
|
@@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
229
334
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
230
335
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
231
336
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
232
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
233
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
234
337
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
235
338
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
236
|
-
/* .
|
|
339
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
|
340
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
|
341
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
|
237
342
|
},
|
|
238
343
|
/* .gemm_lhs_info = */ {
|
|
239
344
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
240
|
-
/* .
|
|
241
|
-
/* .
|
|
242
|
-
/* .
|
|
345
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
|
346
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
|
347
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
|
243
348
|
},
|
|
244
349
|
/* DOTPROD GEMV */
|
|
245
350
|
/* .kern_info = */ {
|
|
@@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
249
354
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
250
355
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
251
356
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
252
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
253
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
254
357
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
255
358
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
256
|
-
/* .
|
|
359
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
|
360
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
|
361
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
|
257
362
|
},
|
|
258
363
|
/* .gemv_lhs_info = */ {
|
|
259
364
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
260
|
-
/* .
|
|
261
|
-
/* .
|
|
262
|
-
/* .
|
|
365
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
|
366
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
|
367
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
|
263
368
|
},
|
|
264
369
|
/* .rhs_info = */ {
|
|
265
|
-
/* .
|
|
266
|
-
/* .
|
|
267
|
-
/* .
|
|
268
|
-
/* .
|
|
370
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
371
|
+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
|
372
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
373
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
374
|
+
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
269
375
|
},
|
|
270
376
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
271
377
|
/* .lhs_type = */ GGML_TYPE_F32,
|
|
@@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
283
389
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
284
390
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
285
391
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
286
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
287
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
288
392
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
289
393
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
290
|
-
/* .
|
|
394
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
|
395
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
|
396
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
|
291
397
|
},
|
|
292
398
|
/* .gemm_lhs_info = */ {
|
|
293
399
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
|
294
|
-
/* .
|
|
295
|
-
/* .
|
|
296
|
-
/* .
|
|
400
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
|
401
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
|
402
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
|
297
403
|
},
|
|
298
404
|
/* i8mm GEMV */
|
|
299
405
|
/* .kern_info = */ {
|
|
@@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
303
409
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
304
410
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
305
411
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
306
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
307
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
308
412
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
309
413
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
310
|
-
/* .
|
|
414
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
|
415
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
|
416
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
|
311
417
|
},
|
|
312
418
|
/* .gemv_lhs_info = */ {
|
|
313
419
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
314
|
-
/* .
|
|
315
|
-
/* .
|
|
316
|
-
/* .
|
|
420
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
|
421
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
|
422
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
|
317
423
|
},
|
|
318
424
|
/* .rhs_info = */ {
|
|
319
|
-
/* .
|
|
320
|
-
/* .
|
|
321
|
-
/* .
|
|
322
|
-
/* .
|
|
425
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
426
|
+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
|
427
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
428
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
429
|
+
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
323
430
|
},
|
|
324
431
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
325
432
|
/* .lhs_type = */ GGML_TYPE_F32,
|
|
@@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
338
445
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
339
446
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
340
447
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
341
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
342
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
343
448
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
344
449
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
345
|
-
/* .
|
|
450
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
|
451
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
|
452
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
|
346
453
|
},
|
|
347
454
|
/* .gemm_lhs_info = */ {
|
|
348
455
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
|
349
|
-
/* .
|
|
350
|
-
/* .
|
|
351
|
-
/* .
|
|
456
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
|
457
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
|
458
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
|
352
459
|
},
|
|
353
460
|
/* i8mm GEMV */
|
|
354
461
|
/* .kern_info = */ {
|
|
@@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
358
465
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
359
466
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
360
467
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
361
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
362
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
363
468
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
364
469
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
365
|
-
/* .
|
|
470
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
|
471
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
|
472
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
|
366
473
|
},
|
|
367
474
|
/* .gemv_lhs_info = */ {
|
|
368
475
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
369
|
-
/* .
|
|
370
|
-
/* .
|
|
371
|
-
/* .
|
|
476
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
|
477
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
|
478
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
|
372
479
|
},
|
|
373
480
|
/* .rhs_info = */ {
|
|
374
|
-
/* .
|
|
375
|
-
/* .
|
|
376
|
-
/* .
|
|
377
|
-
/* .
|
|
481
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
482
|
+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
|
483
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
484
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
485
|
+
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
378
486
|
},
|
|
379
487
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
380
488
|
/* .lhs_type = */ GGML_TYPE_F32,
|
|
@@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
392
500
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
393
501
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
394
502
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
395
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
396
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
397
503
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
398
504
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
399
|
-
/* .
|
|
505
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
|
506
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
|
507
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
|
400
508
|
},
|
|
401
509
|
/* .gemm_lhs_info = */ {
|
|
402
510
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
403
|
-
/* .
|
|
404
|
-
/* .
|
|
405
|
-
/* .
|
|
511
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
|
512
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
|
513
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
|
406
514
|
},
|
|
407
515
|
/* DOTPROD GEMV */
|
|
408
516
|
/* .kern_info = */ {
|
|
@@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
412
520
|
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
413
521
|
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
414
522
|
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
415
|
-
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
416
|
-
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
417
523
|
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
418
524
|
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
419
|
-
/* .
|
|
525
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
|
526
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
|
527
|
+
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
|
420
528
|
},
|
|
421
529
|
/* .gemv_lhs_info = */ {
|
|
422
530
|
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
423
|
-
/* .
|
|
424
|
-
/* .
|
|
425
|
-
/* .
|
|
531
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
|
532
|
+
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
|
533
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
|
426
534
|
},
|
|
427
535
|
/* .rhs_info = */ {
|
|
428
|
-
/* .
|
|
429
|
-
/* .
|
|
430
|
-
/* .
|
|
431
|
-
/* .
|
|
536
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
537
|
+
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
|
538
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
539
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
540
|
+
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
|
432
541
|
},
|
|
433
542
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
434
543
|
/* .lhs_type = */ GGML_TYPE_F32,
|
|
@@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|
|
443
552
|
ggml_kleidiai_kernels * kernel = nullptr;
|
|
444
553
|
|
|
445
554
|
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
|
555
|
+
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
|
446
556
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
|
447
557
|
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
|
448
558
|
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
|
@@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|
|
452
562
|
break;
|
|
453
563
|
}
|
|
454
564
|
}
|
|
565
|
+
#endif
|
|
455
566
|
}
|
|
456
567
|
|
|
457
568
|
return kernel;
|
|
@@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|
|
460
571
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
|
|
461
572
|
ggml_kleidiai_kernels * kernels = nullptr;
|
|
462
573
|
|
|
574
|
+
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
|
463
575
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
|
464
576
|
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
|
465
577
|
kernels = &gemm_gemv_kernels[i];
|
|
466
578
|
break;
|
|
467
579
|
}
|
|
468
580
|
}
|
|
581
|
+
#endif
|
|
469
582
|
|
|
470
583
|
return kernels;
|
|
471
584
|
}
|
|
@@ -4,8 +4,6 @@
|
|
|
4
4
|
|
|
5
5
|
#pragma once
|
|
6
6
|
|
|
7
|
-
#include <functional>
|
|
8
|
-
#include <variant>
|
|
9
7
|
#include "ggml.h"
|
|
10
8
|
|
|
11
9
|
enum cpu_feature {
|
|
@@ -15,6 +13,7 @@ enum cpu_feature {
|
|
|
15
13
|
CPU_FEATURE_SVE = 4,
|
|
16
14
|
CPU_FEATURE_SME = 8
|
|
17
15
|
};
|
|
16
|
+
|
|
18
17
|
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
|
|
19
18
|
lhs = static_cast<cpu_feature>(lhs | rhs);
|
|
20
19
|
return lhs;
|
|
@@ -30,63 +29,52 @@ struct kernel_info {
|
|
|
30
29
|
size_t (*get_nr)(void);
|
|
31
30
|
size_t (*get_kr)(void);
|
|
32
31
|
size_t (*get_sr)(void);
|
|
33
|
-
|
|
34
|
-
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
|
35
|
-
std::function<size_t(size_t m_idx, size_t k)>
|
|
36
|
-
> get_lhs_offset;
|
|
37
|
-
std::variant<
|
|
38
|
-
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
|
39
|
-
std::function<size_t(size_t n_idx, size_t k)>
|
|
40
|
-
> get_rhs_packed_offset;
|
|
32
|
+
|
|
41
33
|
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
|
|
42
34
|
size_t (*get_dst_size)(size_t m, size_t n);
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
35
|
+
|
|
36
|
+
size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);
|
|
37
|
+
|
|
38
|
+
size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);
|
|
39
|
+
|
|
40
|
+
void (*run_kernel_ex)(
|
|
41
|
+
size_t m, size_t n, size_t k, size_t bl,
|
|
42
|
+
const void* lhs_packed, const void* rhs_packed,
|
|
43
|
+
void* dst, size_t dst_stride_row, size_t dst_stride_col,
|
|
44
|
+
float clamp_min, float clamp_max);
|
|
49
45
|
};
|
|
50
46
|
|
|
51
47
|
struct lhs_packing_info {
|
|
52
48
|
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
> packed_size;
|
|
61
|
-
std::variant<
|
|
62
|
-
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
|
63
|
-
size_t lhs_stride, void* lhs_packed)>,
|
|
64
|
-
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
|
|
65
|
-
void* lhs_packed)>
|
|
66
|
-
> pack_func;
|
|
49
|
+
|
|
50
|
+
size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
|
51
|
+
|
|
52
|
+
size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
|
53
|
+
|
|
54
|
+
void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
|
|
55
|
+
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);
|
|
67
56
|
};
|
|
68
57
|
|
|
69
58
|
struct rhs_packing_info {
|
|
70
|
-
std::variant<
|
|
71
|
-
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
|
72
|
-
std::function<size_t(size_t n, size_t k)>
|
|
73
|
-
> packed_size;
|
|
74
59
|
size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
60
|
+
|
|
61
|
+
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out,
|
|
62
|
+
size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl,
|
|
63
|
+
size_t num_bytes_multiplier);
|
|
64
|
+
|
|
65
|
+
size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
|
|
66
|
+
|
|
67
|
+
size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);
|
|
68
|
+
|
|
69
|
+
void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
|
|
70
|
+
size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params);
|
|
83
71
|
};
|
|
84
72
|
|
|
85
73
|
struct ggml_kleidiai_kernels {
|
|
86
|
-
kernel_info
|
|
74
|
+
kernel_info gemm;
|
|
87
75
|
lhs_packing_info gemm_lhs_info;
|
|
88
76
|
|
|
89
|
-
kernel_info
|
|
77
|
+
kernel_info gemv;
|
|
90
78
|
lhs_packing_info gemv_lhs_info;
|
|
91
79
|
|
|
92
80
|
rhs_packing_info rhs_info;
|