@fugood/llama.node 1.1.9 → 1.1.11

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (48)
  1. package/lib/binding.ts +7 -1
  2. package/package.json +14 -14
  3. package/scripts/llama.cpp.patch +15 -5
  4. package/src/LlamaCompletionWorker.cpp +12 -3
  5. package/src/LlamaCompletionWorker.h +3 -1
  6. package/src/LlamaContext.cpp +20 -2
  7. package/src/llama.cpp/common/arg.cpp +29 -19
  8. package/src/llama.cpp/common/chat.cpp +153 -3
  9. package/src/llama.cpp/common/chat.h +1 -0
  10. package/src/llama.cpp/common/common.cpp +10 -3
  11. package/src/llama.cpp/common/common.h +4 -1
  12. package/src/llama.cpp/ggml/CMakeLists.txt +1 -1
  13. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +6 -4
  14. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -1
  15. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  16. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  17. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +14 -9
  18. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  19. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +16 -12
  20. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +39 -14
  21. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +20 -1
  22. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +103 -1
  23. package/src/llama.cpp/include/llama.h +27 -1
  24. package/src/llama.cpp/src/llama-adapter.cpp +68 -4
  25. package/src/llama.cpp/src/llama-adapter.h +3 -0
  26. package/src/llama.cpp/src/llama-arch.cpp +46 -2
  27. package/src/llama.cpp/src/llama-arch.h +4 -0
  28. package/src/llama.cpp/src/llama-context.cpp +80 -39
  29. package/src/llama.cpp/src/llama-context.h +0 -4
  30. package/src/llama.cpp/src/llama-graph.cpp +20 -10
  31. package/src/llama.cpp/src/llama-graph.h +2 -1
  32. package/src/llama.cpp/src/llama-hparams.cpp +25 -0
  33. package/src/llama.cpp/src/llama-hparams.h +6 -0
  34. package/src/llama.cpp/src/llama-impl.h +2 -0
  35. package/src/llama.cpp/src/llama-kv-cache-iswa.cpp +24 -7
  36. package/src/llama.cpp/src/llama-kv-cache-iswa.h +4 -2
  37. package/src/llama.cpp/src/llama-kv-cache.cpp +67 -130
  38. package/src/llama.cpp/src/llama-kv-cache.h +16 -28
  39. package/src/llama.cpp/src/llama-memory-hybrid.cpp +29 -28
  40. package/src/llama.cpp/src/llama-memory-hybrid.h +18 -22
  41. package/src/llama.cpp/src/llama-memory-recurrent.cpp +7 -7
  42. package/src/llama.cpp/src/llama-memory-recurrent.h +7 -11
  43. package/src/llama.cpp/src/llama-memory.h +8 -0
  44. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  45. package/src/llama.cpp/src/llama-model.cpp +302 -31
  46. package/src/llama.cpp/src/llama-model.h +1 -0
  47. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  48. package/src/llama.cpp/src/llama.cpp +12 -0
package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt

@@ -435,7 +435,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         )
         if (GGML_RVV)
             if (GGML_XTHEADVECTOR)
-                list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d)
+                list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d)
             elseif (GGML_RV_ZFH)
                 list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
             else()
@@ -497,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.13.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
+        set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -555,6 +555,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         list(APPEND GGML_KLEIDIAI_SOURCES
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
@@ -576,7 +577,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
-                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
+                ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
             set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
         endif()
 
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h

@@ -489,7 +489,7 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
 /**
  * @see https://github.com/ggml-org/llama.cpp/pull/14037
  */
-inline float vec_hsum(float32x4_t v) {
+inline static float vec_hsum(float32x4_t v) {
     float32x4_t v_temp = v + vec_reve(v);
     return v_temp[0] + v_temp[1];
 }
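The one-word change above gives `vec_hsum` internal linkage. This header is included from both C and C++ translation units, and in C a bare `inline` definition provides no out-of-line symbol: if the compiler declines to inline a call, the link fails with an undefined reference. `inline static` guarantees that each including translation unit carries its own private copy. A generic illustration (hypothetical names, not the real header):

    /* demo_impl.h -- illustrative header, not ggml-cpu-impl.h itself. */

    /* In C, a bare 'inline' function supplies only an "inline definition":
     * the compiler may inline calls, but no linkable symbol is emitted, so
     * any non-inlined call needs an external definition elsewhere:
     *
     *     inline float hsum2_demo(float a, float b) { return a + b; }
     *
     * 'inline static' instead gives the function internal linkage: every
     * .c or .cpp file that includes this header gets its own private copy,
     * and calls resolve regardless of the compiler's inlining decision. */
    inline static float hsum2_demo(float a, float b) {
        return a + b;
    }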
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp

@@ -14,6 +14,7 @@
 
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
 
 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
@@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+    },
     /* SME GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
@@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
@@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
+        /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+    },
     /* SME GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
@@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
@@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+    },
     /* DOTPROD GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+    },
     /* i8mm GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+    },
     /* i8mm GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+    },
     /* DOTPROD GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h

@@ -84,8 +84,11 @@ struct rhs_packing_info {
 
 struct ggml_kleidiai_kernels {
     kernel_info gemm;
+    lhs_packing_info gemm_lhs_info;
+
     kernel_info gemv;
-    lhs_packing_info lhs_info;
+    lhs_packing_info gemv_lhs_info;
+
     rhs_packing_info rhs_info;
 
     cpu_feature required_cpu;
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

@@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
-        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        bool is_gemv = op->src[1]->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         size_t k = op->src[0]->ne[0];
         size_t n = op->src[0]->ne[1];
@@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t sr = kernel->get_sr();
 
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+            size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+            size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
                    variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
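Aside: `variant_call` in the hunk above invokes an entry from a packing-info struct whose callbacks differ in arity between the Q4_0 path (which passes the QK4_0 block length) and the F16 path. A minimal sketch of how such a dispatcher can be written, assuming the callbacks are stored in a std::variant of function-pointer types (illustrative only; the real helper lives in kleidiai.cpp and may differ):

    // variant_call_sketch.cpp -- illustrative, not the actual kleidiai.cpp helper.
    #include <cstddef>
    #include <cstdio>
    #include <stdexcept>
    #include <type_traits>
    #include <utility>
    #include <variant>

    // Hypothetical packed-size signatures mirroring the two calls above:
    // the quantized path takes a block length (bl); the F16 path does not.
    using packed_size_q4  = size_t (*)(size_t m, size_t k, size_t bl,
                                       size_t mr, size_t kr, size_t sr);
    using packed_size_f16 = size_t (*)(size_t m, size_t k,
                                       size_t mr, size_t kr, size_t sr);

    // Visit the variant and call whichever alternative accepts the supplied
    // arguments; a non-matching alternative falls into the throwing branch.
    template <typename Ret, typename Variant, typename... Args>
    static Ret variant_call(const Variant & var, Args &&... args) {
        return std::visit([&](auto && fn) -> Ret {
            if constexpr (std::is_invocable_r_v<Ret, decltype(fn), Args...>) {
                return fn(std::forward<Args>(args)...);
            } else {
                throw std::runtime_error("variant_call: argument mismatch");
            }
        }, var);
    }

    int main() {
        std::variant<packed_size_q4, packed_size_f16> packed_size =
            static_cast<packed_size_f16>(
                [](size_t m, size_t k, size_t, size_t, size_t) -> size_t {
                    return m * k * sizeof(float);  // placeholder size computation
                });
        // Five size_t arguments select the F16-style alternative at compile time.
        std::printf("packed size: %zu\n",
                    variant_call<size_t>(packed_size, size_t(4), size_t(128),
                                         size_t(4), size_t(8), size_t(2)));
        return 0;
    }

The remaining kleidiai.cpp hunks apply the same `lhs_info` substitution: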
@@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
 
         const int nth = params->nth;
@@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t kr = static_cast<int64_t>(kernel->get_kr());
         const int64_t sr = static_cast<int64_t>(kernel->get_sr());
 
-        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
         const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
         const size_t kxn_size = k * n * sizeof(float);
         const size_t bias_size = n * sizeof(float);
@@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
 
             const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
-            const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
 
             const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
             void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
 
-            variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+            variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
         }
     }
 
@@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = &kernels->lhs_info;
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         GGML_ASSERT(kernel);