@fugood/llama.node 1.1.10 → 1.2.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +20 -2
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +174 -388
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +67 -37
  27. package/src/llama.cpp/common/chat.cpp +263 -2
  28. package/src/llama.cpp/common/chat.h +4 -0
  29. package/src/llama.cpp/common/common.cpp +10 -3
  30. package/src/llama.cpp/common/common.h +5 -2
  31. package/src/llama.cpp/common/log.cpp +53 -2
  32. package/src/llama.cpp/common/log.h +10 -4
  33. package/src/llama.cpp/common/sampling.cpp +23 -2
  34. package/src/llama.cpp/common/sampling.h +3 -1
  35. package/src/llama.cpp/common/speculative.cpp +1 -1
  36. package/src/llama.cpp/ggml/CMakeLists.txt +4 -3
  37. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
  38. package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
  39. package/src/llama.cpp/ggml/include/ggml.h +50 -1
  40. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +19 -16
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  42. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -7
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
  45. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  46. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +18 -18
  48. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  49. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +234 -16
  50. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  51. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +80 -51
  52. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +161 -20
  53. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +399 -50
  54. package/src/llama.cpp/include/llama.h +32 -7
  55. package/src/llama.cpp/src/llama-adapter.cpp +101 -4
  56. package/src/llama.cpp/src/llama-adapter.h +6 -0
  57. package/src/llama.cpp/src/llama-arch.cpp +69 -2
  58. package/src/llama.cpp/src/llama-arch.h +6 -0
  59. package/src/llama.cpp/src/llama-context.cpp +92 -45
  60. package/src/llama.cpp/src/llama-context.h +1 -5
  61. package/src/llama.cpp/src/llama-graph.cpp +74 -19
  62. package/src/llama.cpp/src/llama-graph.h +10 -1
  63. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  64. package/src/llama.cpp/src/llama-hparams.h +9 -3
  65. package/src/llama.cpp/src/llama-impl.h +2 -0
  66. package/src/llama.cpp/src/llama-kv-cache.cpp +33 -120
  67. package/src/llama.cpp/src/llama-kv-cache.h +4 -13
  68. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  69. package/src/llama.cpp/src/llama-model.cpp +434 -21
  70. package/src/llama.cpp/src/llama-model.h +1 -1
  71. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  72. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  73. package/src/llama.cpp/src/llama.cpp +12 -0
  74. package/src/anyascii.c +0 -22223
  75. package/src/anyascii.h +0 -42
  76. package/src/tts_utils.cpp +0 -371
  77. package/src/tts_utils.h +0 -103
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -348,8 +348,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t *
     long pages = sysconf(_SC_PHYS_PAGES);
     long page_size = sysconf(_SC_PAGE_SIZE);
     *total = pages * page_size;
+
+    // "free" system memory is ill-defined, for practical purposes assume that all of it is free:
     *free = *total;
-#endif
+#endif // _WIN32
 
     GGML_UNUSED(dev);
 }
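
For reference, a minimal C++ sketch of the non-Windows branch shown above (the standalone helper name is hypothetical): total physical memory is derived from sysconf(), and, because free system memory is ill-defined, the free amount is simply reported as equal to the total.

#include <cstddef>
#include <unistd.h>

// Hypothetical helper mirroring the logic of ggml_backend_cpu_device_get_memory on non-Windows.
static void cpu_device_memory_sketch(size_t * free_mem, size_t * total_mem) {
    long pages     = sysconf(_SC_PHYS_PAGES);
    long page_size = sysconf(_SC_PAGE_SIZE);
    *total_mem = (size_t) pages * (size_t) page_size;
    *free_mem  = *total_mem; // "free" is ill-defined; assume all of it is free
}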
@@ -576,9 +578,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
     if (ggml_cpu_has_vxe()) {
         features.push_back({ "VXE", "1" });
     }
-    if (ggml_cpu_has_nnpa()) {
-        features.push_back({ "NNPA", "1" });
-    }
     if (ggml_cpu_has_wasm_simd()) {
         features.push_back({ "WASM_SIMD", "1" });
     }
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -14,6 +14,7 @@
 
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
 
 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
@@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+    },
     /* SME GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
@@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
@@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
+        /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+    },
     /* SME GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
@@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
@@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+    },
     /* DOTPROD GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+    },
     /* i8mm GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+    },
     /* i8mm GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
     },
+    /* .gemm_lhs_info = */ {
+        /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+        /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+        /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+    },
     /* DOTPROD GEMV */
     /* .kern_info = */ {
         /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
         /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
     },
-    /* .lhs_info = */ {
+    /* .gemv_lhs_info = */ {
         /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
         /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h
@@ -84,8 +84,11 @@ struct rhs_packing_info {
 
 struct ggml_kleidiai_kernels {
     kernel_info gemm;
+    lhs_packing_info gemm_lhs_info;
+
     kernel_info gemv;
-    lhs_packing_info lhs_info;
+    lhs_packing_info gemv_lhs_info;
+
     rhs_packing_info rhs_info;
 
     cpu_feature required_cpu;
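
The change above splits the single shared lhs_info into per-path gemm_lhs_info and gemv_lhs_info, so a kernel entry can use a different LHS packing routine for GEMM than for GEMV (for example, the i8mm GEMM entries now pack with kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon while their GEMV counterparts keep kai_lhs_quant_pack_qsi8d32p_f32). A minimal sketch of how callers select between the two after this split, mirroring the kleidiai.cpp hunks below (kernels and op stand for a selected ggml_kleidiai_kernels entry and the mul_mat op being evaluated):

    // GEMV when the second operand has a single column, GEMM otherwise;
    // pick the matching LHS packing info alongside the kernel.
    bool is_gemv = op->src[1]->ne[1] == 1;
    kernel_info *      kernel   = is_gemv ? &kernels->gemv          : &kernels->gemm;
    lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;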
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
-        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        bool is_gemv = op->src[1]->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         size_t k = op->src[0]->ne[0];
         size_t n = op->src[0]->ne[1];
@@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t sr = kernel->get_sr();
 
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+            size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+            size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
                    variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
@@ -152,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                 return compute_forward_q4_0(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
-                return compute_forward_kv_cache(params, dst);
+                return compute_forward_fp16(params, dst);
             }
         } else if (dst->op == GGML_OP_GET_ROWS) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -162,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             return false;
         }
 
-    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
+    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
         static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
 
         const ggml_tensor * src0 = dst->src[0];
@@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
 
         const int nth = params->nth;
@@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t kr = static_cast<int64_t>(kernel->get_kr());
         const int64_t sr = static_cast<int64_t>(kernel->get_sr());
 
-        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
         const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
         const size_t kxn_size = k * n * sizeof(float);
         const size_t bias_size = n * sizeof(float);
@@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
                 const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
 
                 const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
-                const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+                const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
 
                 const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
                 void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
 
-                variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
             }
         }
 
@@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = &kernels->lhs_info;
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         GGML_ASSERT(kernel);
 
@@ -529,13 +534,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
             return (ggml::cpu::tensor_traits *) op->src[0]->extra;
         }
-        else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
-                 op->src[0]->op == GGML_OP_VIEW &&
-                 (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
-                 op->src[1]->ne[1] > 1) {
-            if ((op->src[0]->nb[0] != 2) ||
-                (op->src[1]->nb[0] != 4) ||
-                (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+        else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
+            if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                 (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                 return nullptr;
             }