@fugood/llama.node 1.1.10 → 1.2.0-rc.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +5 -8
- package/lib/binding.ts +20 -2
- package/lib/index.js +2 -2
- package/lib/index.ts +2 -2
- package/package.json +20 -16
- package/src/DecodeAudioTokenWorker.cpp +23 -26
- package/src/DecodeAudioTokenWorker.h +6 -8
- package/src/DetokenizeWorker.cpp +5 -8
- package/src/DetokenizeWorker.h +6 -5
- package/src/DisposeWorker.cpp +23 -3
- package/src/DisposeWorker.h +4 -2
- package/src/EmbeddingWorker.cpp +9 -35
- package/src/EmbeddingWorker.h +3 -2
- package/src/LlamaCompletionWorker.cpp +217 -315
- package/src/LlamaCompletionWorker.h +6 -12
- package/src/LlamaContext.cpp +174 -388
- package/src/LlamaContext.h +8 -13
- package/src/LoadSessionWorker.cpp +22 -19
- package/src/LoadSessionWorker.h +3 -2
- package/src/RerankWorker.h +3 -2
- package/src/SaveSessionWorker.cpp +22 -19
- package/src/SaveSessionWorker.h +3 -2
- package/src/TokenizeWorker.cpp +38 -35
- package/src/TokenizeWorker.h +12 -3
- package/src/common.hpp +0 -458
- package/src/llama.cpp/common/arg.cpp +67 -37
- package/src/llama.cpp/common/chat.cpp +263 -2
- package/src/llama.cpp/common/chat.h +4 -0
- package/src/llama.cpp/common/common.cpp +10 -3
- package/src/llama.cpp/common/common.h +5 -2
- package/src/llama.cpp/common/log.cpp +53 -2
- package/src/llama.cpp/common/log.h +10 -4
- package/src/llama.cpp/common/sampling.cpp +23 -2
- package/src/llama.cpp/common/sampling.h +3 -1
- package/src/llama.cpp/common/speculative.cpp +1 -1
- package/src/llama.cpp/ggml/CMakeLists.txt +4 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
- package/src/llama.cpp/ggml/include/ggml.h +50 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +19 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +18 -18
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +234 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +80 -51
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +161 -20
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +399 -50
- package/src/llama.cpp/include/llama.h +32 -7
- package/src/llama.cpp/src/llama-adapter.cpp +101 -4
- package/src/llama.cpp/src/llama-adapter.h +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +69 -2
- package/src/llama.cpp/src/llama-arch.h +6 -0
- package/src/llama.cpp/src/llama-context.cpp +92 -45
- package/src/llama.cpp/src/llama-context.h +1 -5
- package/src/llama.cpp/src/llama-graph.cpp +74 -19
- package/src/llama.cpp/src/llama-graph.h +10 -1
- package/src/llama.cpp/src/llama-hparams.cpp +37 -0
- package/src/llama.cpp/src/llama-hparams.h +9 -3
- package/src/llama.cpp/src/llama-impl.h +2 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +33 -120
- package/src/llama.cpp/src/llama-kv-cache.h +4 -13
- package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
- package/src/llama.cpp/src/llama-model.cpp +434 -21
- package/src/llama.cpp/src/llama-model.h +1 -1
- package/src/llama.cpp/src/llama-sampling.cpp +226 -126
- package/src/llama.cpp/src/llama-vocab.cpp +1 -1
- package/src/llama.cpp/src/llama.cpp +12 -0
- package/src/anyascii.c +0 -22223
- package/src/anyascii.h +0 -42
- package/src/tts_utils.cpp +0 -371
- package/src/tts_utils.h +0 -103
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp

@@ -348,8 +348,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
     long pages = sysconf(_SC_PHYS_PAGES);
     long page_size = sysconf(_SC_PAGE_SIZE);
     *total = pages * page_size;
+
+    // "free" system memory is ill-defined, for practical purposes assume that all of it is free:
     *free = *total;
-#endif
+#endif // _WIN32
 
     GGML_UNUSED(dev);
 }
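
The new comment makes the simplification explicit: POSIX exposes total physical memory via sysconf, but has no portable notion of "free" memory, so the CPU backend reports all of it as free. A minimal standalone sketch of the full platform split this hunk sits in (the Windows branch with GlobalMemoryStatusEx is assumed here, since the hunk only shows the POSIX side, and the function name is illustrative):

#include <cstddef>

#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h>
#endif

// Report physical memory. On POSIX, "free" is ill-defined, so it is
// reported as equal to total, matching the comment in the hunk above.
static void cpu_device_get_memory(size_t * free_mem, size_t * total) {
#ifdef _WIN32
    MEMORYSTATUSEX status;
    status.dwLength = sizeof(status);
    GlobalMemoryStatusEx(&status);
    *total    = (size_t) status.ullTotalPhys;
    *free_mem = (size_t) status.ullAvailPhys;
#else
    long pages     = sysconf(_SC_PHYS_PAGES);
    long page_size = sysconf(_SC_PAGE_SIZE);
    *total    = (size_t) pages * (size_t) page_size;
    *free_mem = *total;
#endif // _WIN32
}

The `// _WIN32` annotation added to the `#endif` simply documents which `#ifdef` the guard closes.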
@@ -576,9 +578,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t reg) {
         if (ggml_cpu_has_vxe()) {
             features.push_back({ "VXE", "1" });
         }
-        if (ggml_cpu_has_nnpa()) {
-            features.push_back({ "NNPA", "1" });
-        }
         if (ggml_cpu_has_wasm_simd()) {
             features.push_back({ "WASM_SIMD", "1" });
         }
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp

@@ -14,6 +14,7 @@
 
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
 
 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
@@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
         },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+        },
         /* SME GEMV */
         /* .kern_info = */ {
             /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
@@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
         },
-        /* .lhs_info = */ {
+        /* .gemv_lhs_info = */ {
             /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
             /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
@@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
         },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+        },
         /* SME GEMV */
         /* .kern_info = */ {
             /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
@@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
         },
-        /* .lhs_info = */ {
+        /* .gemv_lhs_info = */ {
             /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
             /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
@@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
         },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+        },
         /* DOTPROD GEMV */
         /* .kern_info = */ {
             /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
         },
-        /* .lhs_info = */ {
+        /* .gemv_lhs_info = */ {
             /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
         },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        },
         /* i8mm GEMV */
         /* .kern_info = */ {
             /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
         },
-        /* .lhs_info = */ {
+        /* .gemv_lhs_info = */ {
             /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
         },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        },
         /* i8mm GEMV */
         /* .kern_info = */ {
             /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
         },
-        /* .lhs_info = */ {
+        /* .gemv_lhs_info = */ {
             /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
         },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+        },
         /* DOTPROD GEMV */
         /* .kern_info = */ {
             /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
         },
-        /* .lhs_info = */ {
+        /* .gemv_lhs_info = */ {
             /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
             /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h

@@ -84,8 +84,11 @@ struct rhs_packing_info {
 
 struct ggml_kleidiai_kernels {
     kernel_info gemm;
+    lhs_packing_info gemm_lhs_info;
+
     kernel_info gemv;
-    lhs_packing_info lhs_info;
+    lhs_packing_info gemv_lhs_info;
+
     rhs_packing_info rhs_info;
 
     cpu_feature required_cpu;
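
Splitting the single `lhs_info` member into `gemm_lhs_info` and `gemv_lhs_info` lets each kernel table entry carry a different LHS packing routine per path: in the kernels.cpp tables above, the i8mm GEMM path now uses the blocked `qsi8d32p4x8sb` packer while its GEMV path keeps the plain `qsi8d32p_f32` one. A reduced sketch of the resulting dispatch, with simplified stand-in types for the real `kernel_info`/`lhs_packing_info` (the selection rule itself appears verbatim in the kleidiai.cpp hunks below):

#include <cstdint>

struct kernel_info_s      { /* matmul entry points */ };
struct lhs_packing_info_s { /* offset/size/pack entry points */ };

struct kleidiai_kernels_s {
    kernel_info_s      gemm;
    lhs_packing_info_s gemm_lhs_info;  // blocked packing for M > 1
    kernel_info_s      gemv;
    lhs_packing_info_s gemv_lhs_info;  // row packing for M == 1
};

// GEMV is the single-row case: the activation tensor has ne[1] == 1.
static lhs_packing_info_s * select_lhs_info(kleidiai_kernels_s * k, int64_t ne1) {
    const bool is_gemv = (ne1 == 1);
    return is_gemv ? &k->gemv_lhs_info : &k->gemm_lhs_info;
}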
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

@@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
-        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        bool is_gemv = op->src[1]->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         size_t k = op->src[0]->ne[0];
         size_t n = op->src[0]->ne[1];
@@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t sr = kernel->get_sr();
 
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+            size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+            size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
                    variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
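
The `variant_call<size_t>(lhs_info->packed_size, ...)` sites work because each packing slot stores a variant over the possible KleidiAI function-pointer signatures; as the hunk shows, the Q4_0 packed-size query takes a block length (QK4_0) while the F16 one does not. A sketch of how such a helper can be written, under the assumption of variant-typed slots (llama.cpp's own `variant_call` in kleidiai.cpp follows this same std::visit idiom; the signatures below are illustrative):

#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>

// Two plausible signatures for an LHS packed-size query: with and
// without a quantization block length.
using packed_size_variant = std::variant<
    size_t (*)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr),
    size_t (*)(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>;

// Invoke whichever stored alternative is callable with the given arguments.
template <typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args &&... args) {
    return std::visit([&](auto && fn) -> Ret {
        if constexpr (std::is_invocable_r_v<Ret, decltype(fn), Args...>) {
            return fn(std::forward<Args>(args)...);
        } else {
            throw std::runtime_error("variant_call: arguments do not match the stored signature");
        }
    }, var);
}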
@@ -152,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                 return compute_forward_q4_0(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
-                return compute_forward_kv_cache(params, dst);
+                return compute_forward_fp16(params, dst);
             }
         } else if (dst->op == GGML_OP_GET_ROWS) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -162,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return false;
     }
 
-    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
+    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
         static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
 
         const ggml_tensor * src0 = dst->src[0];
@@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
 
         const int nth = params->nth;
@@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t kr = static_cast<int64_t>(kernel->get_kr());
         const int64_t sr = static_cast<int64_t>(kernel->get_sr());
 
-        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
         const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
         const size_t kxn_size = k * n * sizeof(float);
         const size_t bias_size = n * sizeof(float);
@@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
                 const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
 
                 const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
-                const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+                const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
 
                 const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
                 void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
 
-                variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
             }
         }
 
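
This packing loop divides the M rows of the LHS across worker threads, each packing its own slice at the offset returned by `get_packed_offset`. A sketch of one partitioning scheme consistent with the variable names visible in the hunk (`num_m_per_thread0` for every thread but the last, a remainder `num_m_per_threadN_1` for thread N-1); the exact rounding used by ggml may differ:

#include <algorithm>
#include <cstdint>

// Row range for worker `ith` of `nth`: near-equal shares, clamped so the
// ranges tile [0, m) exactly even when m does not divide evenly.
static void lhs_rows_for_thread(int64_t m, int ith, int nth,
                                int64_t & m_start, int64_t & m_count) {
    const int64_t per_thread = (m + nth - 1) / nth;   // ceil(m / nth)
    m_start = std::min<int64_t>((int64_t) ith * per_thread, m);
    m_count = std::min<int64_t>(per_thread, m - m_start);
}

With m = 5 and nth = 3, for example, the workers receive rows [0,2), [2,4), and [4,5).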
@@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = &kernels->lhs_info;
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         GGML_ASSERT(kernel);
 
@@ -529,13 +534,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
             return (ggml::cpu::tensor_traits *) op->src[0]->extra;
         }
-        else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
-                 op->src[0]->op == GGML_OP_VIEW &&
-                 (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
-                 op->src[1]->ne[1] > 1) {
-            if ((op->src[0]->nb[0] != 2) ||
-                (op->src[1]->nb[0] != 4) ||
-                (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+        else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
+            if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                 (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                 return nullptr;
             }
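
The condition kept by this last hunk is a layout check: the KleidiAI path is rejected unless each 2D slice of both operands is packed densely, i.e. the byte stride of dimension 2 equals rows times the row stride. Dropping the source-op and element-size tests widens which mul_mat operands qualify; only this dense-slice requirement remains. A minimal form of the predicate (the `ne`/`nb` field names match ggml_tensor; the wrapper struct here is illustrative):

#include <cstddef>
#include <cstdint>

struct tensor_layout {
    int64_t ne[3];  // elements per dimension
    size_t  nb[3];  // byte strides per dimension
};

// True when the 2D slices are laid out back to back, which is what the
// relaxed supports-check above now requires of src0 and src1.
static bool slices_densely_packed(const tensor_layout & t) {
    return t.nb[1] * (size_t) t.ne[1] == t.nb[2];
}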