@fugood/llama.node 0.3.16 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -3,7 +3,9 @@
 //
 #include <arm_neon.h>
 #include <assert.h>
+#include <atomic>
 #include <cfloat>
+#include <stdexcept>
 #include <stdint.h>
 #include <string.h>
 #if defined(__linux__)
@@ -34,8 +36,9 @@
 #include "ggml-common.h"

 struct ggml_kleidiai_context {
+    cpu_feature features;
     ggml_kleidiai_kernels * kernels;
-} static ctx = { NULL };
+} static ctx = { CPU_FEATURE_NONE, NULL };

 static void init_kleidiai_context(void) {

@@ -47,18 +50,18 @@ static void init_kleidiai_context(void) {
         const char *env_var = getenv("GGML_KLEIDIAI_SME");
         int sme_enabled = 0;

-        cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
-                               (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
-                               (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
+        ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
+                       (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
+                       (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);

         if (env_var) {
             sme_enabled = atoi(env_var);
         }

         if (sme_enabled != 0) {
-            features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
-        ctx.kernels = ggml_kleidiai_select_kernels(features);
+        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
     }
     ggml_critical_section_end();
 }
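
Note on the change above: the detected CPU features are now stored in the global context (ctx.features) rather than in a local, because kernel selection is no longer a one-shot decision at init time; later hunks re-run ggml_kleidiai_select_kernels(ctx.features, op) per operation. The sketch below shows the flag-enum pattern this relies on. It is a minimal standalone illustration: the bit values and the operator overloads are assumptions, since the real cpu_feature definition (in kleidiai/kernels.h) is not part of this diff.

    #include <cstdio>

    // Illustrative bit values; the real cpu_feature enum may differ.
    enum cpu_feature {
        CPU_FEATURE_NONE    = 0,
        CPU_FEATURE_DOTPROD = 1 << 0,
        CPU_FEATURE_I8MM    = 1 << 1,
        CPU_FEATURE_SVE     = 1 << 2,
        CPU_FEATURE_SME     = 1 << 3,
    };

    static cpu_feature operator|(cpu_feature a, cpu_feature b) {
        return static_cast<cpu_feature>(static_cast<int>(a) | static_cast<int>(b));
    }
    static cpu_feature & operator|=(cpu_feature & a, cpu_feature b) {
        return a = a | b;
    }

    int main() {
        // Stand-ins for the ggml_cpu_has_*() runtime probes.
        const bool has_dotprod = true, has_i8mm = true, has_sme = false;

        cpu_feature features = (has_dotprod ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                               (has_i8mm    ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE);
        if (has_sme) {
            features |= CPU_FEATURE_SME;
        }

        // A kernel is eligible only if every feature it requires is present.
        const int required = CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM;
        std::printf("eligible: %s\n", (features & required) == required ? "yes" : "no");
        return 0;
    }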
@@ -68,96 +71,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }

+template<typename Ret, typename Variant, typename... Args>
+static Ret variant_call(const Variant & var, Args&&... args) {
+    return std::visit([&](auto&& func) -> Ret {
+        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
+            return func(std::forward<Args>(args)...);
+        } else {
+            throw std::runtime_error("Invalid function type in variant_call");
+        }
+    }, var);
+}
+
 namespace ggml::cpu::kleidiai {
+
+static size_t round_down(size_t x, size_t y) {
+    return y == 0 ? x : x - (x % y);
+}
+
+static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
+    size_t src_stride = rhs_stride / sizeof(uint16_t);
+    size_t dst_stride = n;
+
+    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
+            uint16_t v = *(src + k_idx + n_idx * src_stride);
+            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
+        }
+    }
+}
+
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
-        GGML_ASSERT(ctx.kernels);
-        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
+        GGML_ASSERT(kernels);
+        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;

         size_t k = op->src[0]->ne[0];
+        size_t n = op->src[0]->ne[1];
         size_t m = op->src[1]->ne[1];

         size_t mr = kernel->get_mr();
         size_t kr = kernel->get_kr();
         size_t sr = kernel->get_sr();

-        size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
+        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+        } else if (kernels->rhs_type == GGML_TYPE_F16) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
+                   k * n * sizeof(float) + n * sizeof(float);
+        } else {
+            GGML_ASSERT(false);
+        }

         return true;
     }

+
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
-            const ggml_tensor * src0 = dst->src[0];
-            const ggml_tensor * src1 = dst->src[1];
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_q4_0(params, dst);
+            } else if (dst->src[0]->type == GGML_TYPE_F16) {
+                return compute_forward_kv_cache(params, dst);
+            }
+        }
+        return false;
+    }

-            GGML_TENSOR_BINARY_OP_LOCALS
+    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
+        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;

-            GGML_ASSERT(ctx.kernels);
-            kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
-            lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];

-            GGML_ASSERT(kernel);
+        GGML_TENSOR_BINARY_OP_LOCALS

-            const int ith = params->ith;
-            const int nth = params->nth;
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);

-            const size_t k = ne00;
-            const size_t m = ne11;
-            const size_t n = ne01;
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        GGML_ASSERT(kernel);

-            const size_t n_step = kernel->get_n_step();
-            const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-            const size_t n_start = ith * num_n_per_thread;
+        const int nth = params->nth;
+        const int ith = params->ith;

-            size_t n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
-            }
+        const int64_t lhs_batch_size0 = ne12;
+        const int64_t rhs_batch_size0 = ne02;
+        const int64_t batch_size = rhs_batch_size0;
+
+        const int64_t r = lhs_batch_size0 / rhs_batch_size0;
+
+        const int64_t m = ne11 * r;
+        const int64_t n = ne01;
+        const int64_t k = ne00;
+
+        const size_t lhs_stride = src1->nb[1];
+        const size_t rhs_stride = src0->nb[1];
+        const size_t dst_stride = dst->nb[1];
+
+        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
+        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
+        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
+        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+
+        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
+        const size_t kxn_size = k * n * sizeof(float);
+        const size_t bias_size = n * sizeof(float);
+
+        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
+        GGML_ASSERT(wsize_required <= params->wsize);
+
+        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
+        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
+        uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
+        uint8_t * bias = rhs_kxn + kxn_size;
+
+        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
+            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
+            uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;

-            const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
-            uint8_t * lhs_packed = (uint8_t*)params->wdata;
-            const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
-
-            size_t mr = kernel->get_mr();
-            size_t kr = kernel->get_kr();
-            size_t sr = kernel->get_sr();
-
-            // Calculate number of columns to be processed per thread
-            const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
-            const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
-            const size_t m_start = ith * num_m_per_thread;
-            size_t m_to_process = num_m_per_thread;
-            if ((m_start + m_to_process) > m) {
-                m_to_process = m - m_start;
+            // LHS packing
+            {
+                const int64_t m_roundup_mr = kai_roundup(m, mr);
+                const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
+
+                if (ith < num_threads) {
+                    const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
+
+                    const int64_t m_start = ith * num_m_per_thread0;
+                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+
+                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
+                    void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+
+                    variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                }
             }

-            if(m_start < m) {
-                // Transform LHS
-                const size_t src_stride = src1->nb[1];
-                const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
-                const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
-                void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+            // RHS packing
+            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
+                // First thread to reach this point handles RHS packing
+                memset(bias, 0, n * sizeof(float));
+                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);

-                lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
+                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
+                                   rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
             }

             ggml_barrier(params->threadpool);

-            // Perform the operation
-            const size_t dst_stride = dst->nb[1];
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
-            const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
-            const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
-            const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-            const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
-            float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
-
-            kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
-                               dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
-            return true;
+            first_to_arrive.clear(std::memory_order_release);
+
+            // Perform the matmul
+            {
+                const int64_t m_to_process = m;
+                const int64_t m_start = 0;
+
+                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
+                const int64_t num_threads = KAI_MIN(n / n_step, nth);
+
+                if (ith < num_threads) {
+                    const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+
+                    const int64_t n_start = ith * num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
+                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
+                    const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
+
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
+                    float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+
+                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                }
+            }
+
+            if (batch_idx != batch_size - 1) {
+                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
+                // the work data buffer (params->wdata) is used as temporary storage which means that only
+                // a single batch can be processed at any given time. No barrier is needed for the last
+                // batch since GGML inserts a barrier between the execution of every operator.
+                ggml_barrier(params->threadpool);
+            }
         }
-        return false;
+
+        return true;
+    }
+
+    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);
+
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = &kernels->lhs_info;
+
+        GGML_ASSERT(kernel);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const size_t k = ne00;
+        const size_t m = ne11;
+        const size_t n = ne01;
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
+        uint8_t * lhs_packed = (uint8_t*)params->wdata;
+        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+        const size_t n_step = kernel->get_n_step();
+        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+        const size_t n_start = ith * num_n_per_thread;
+
+        size_t n_to_process = num_n_per_thread;
+        if ((n_start + n_to_process) > n) {
+            n_to_process = n - n_start;
+        }
+
+        // Calculate number of columns to be processed per thread
+        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+        const size_t m_start = ith * num_m_per_thread;
+        size_t m_to_process = num_m_per_thread;
+        if ((m_start + m_to_process) > m) {
+            m_to_process = m - m_start;
+        }
+
+        if (m_start < m) {
+            // Transform LHS
+            const size_t src_stride = src1->nb[1];
+            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
+            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
+            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        // Perform the operation
+        const size_t dst_stride = dst->nb[1];
+        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
+        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
+        const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
+        const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+        const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
+        float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+        variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                           sizeof(float), -FLT_MAX, FLT_MAX);
+
+        return true;
     }

 public:
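
The variant_call helper added above is the key new abstraction in this file: the kernel tables now hold several function-pointer signatures side by side (for example, Q4_0 packing routines that take a QK4_0 block length and F16 routines that do not), stored in a std::variant. std::visit instantiates the lambda for every alternative, if constexpr compiles a real call only where the signature accepts the supplied arguments, and selecting an alternative that does not match throws at runtime. A self-contained sketch of the same pattern follows; the q4_packed_size/f16_packed_size functions and their formulas are hypothetical stand-ins, not the real KleidiAI entry points.

    #include <cstdio>
    #include <stdexcept>
    #include <type_traits>
    #include <utility>
    #include <variant>

    template<typename Ret, typename Variant, typename... Args>
    static Ret variant_call(const Variant & var, Args&&... args) {
        return std::visit([&](auto&& func) -> Ret {
            if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
                return func(std::forward<Args>(args)...);
            } else {
                // Reached only when the variant currently holds an alternative
                // whose signature does not match the argument list.
                throw std::runtime_error("Invalid function type in variant_call");
            }
        }, var);
    }

    // Hypothetical stand-ins for two packed-size signatures: the Q4_0 flavor
    // takes a block length (bl), the F16 flavor does not.
    static size_t q4_packed_size(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
        (void) mr; (void) kr; (void) sr;
        return (m * k / bl) * (bl / 2 + 2);   // toy formula, not the real layout
    }
    static size_t f16_packed_size(size_t m, size_t k, size_t mr, size_t kr, size_t sr) {
        (void) mr; (void) kr; (void) sr;
        return m * k * sizeof(float);         // toy formula, not the real layout
    }

    using packed_size_fn = std::variant<
        size_t (*)(size_t, size_t, size_t, size_t, size_t, size_t),  // Q4_0 flavor
        size_t (*)(size_t, size_t, size_t, size_t, size_t)>;         // F16 flavor

    int main() {
        packed_size_fn packed_size = f16_packed_size;
        // Five arguments match the F16 alternative, so the call dispatches there.
        std::printf("f16:  %zu\n", variant_call<size_t>(packed_size, 8, 64, 4, 4, 2));

        packed_size = q4_packed_size;
        // Six arguments (including the QK4_0-style block length) match the Q4_0 alternative.
        std::printf("q4_0: %zu\n", variant_call<size_t>(packed_size, 8, 64, 32, 4, 4, 2));
        return 0;
    }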
@@ -170,13 +352,13 @@
         size_t sr = ctx.kernels->gemm.get_sr();

 #ifndef NDEBUG
-        const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
+        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
         GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
 #endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
+        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);

         return 0;

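A threading note on compute_forward_kv_cache above: LHS packing and the matmul are sliced across all worker threads, but RHS packing must happen exactly once per batch, so the code elects a leader with first_to_arrive.test_and_set() and then uses ggml_barrier to make the packed buffer visible to everyone before the matmul starts. The standalone sketch below reproduces that election pattern under stated assumptions: std::barrier (C++20, compile with -std=c++20 -pthread) stands in for ggml_barrier, and the printf stands in for the packing work.

    #include <atomic>
    #include <barrier>
    #include <cstdio>
    #include <thread>
    #include <vector>

    static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;

    int main() {
        const int nth = 4;
        std::barrier sync_point(nth);  // plays the role of ggml_barrier()

        std::vector<std::thread> pool;
        for (int ith = 0; ith < nth; ++ith) {
            pool.emplace_back([&, ith] {
                // test_and_set() returns false for exactly one thread: the
                // winner does the one-off shared work (RHS packing in the diff).
                if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
                    std::printf("thread %d packs the shared buffer\n", ith);
                }

                // Nobody proceeds until the winner has finished packing.
                sync_point.arrive_and_wait();

                // Clearing an already-clear flag is harmless, so every thread
                // may reset it for the next round, as the diff does per batch.
                first_to_arrive.clear(std::memory_order_release);

                // ... each thread now computes its own slice of the matmul ...
            });
        }
        for (auto & t : pool) {
            t.join();
        }
        return 0;
    }
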
@@ -190,7 +372,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 }
 }  // namespace ggml::cpu::kleidiai

-GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);

     GGML_UNUSED(buffer);
@@ -239,12 +421,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if ( op->op == GGML_OP_MUL_MAT &&
-            op->src[0]->type == GGML_TYPE_Q4_0 &&
-            op->src[0]->buffer &&
-            (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
-            ) {
+        if (op->op == GGML_OP_MUL_MAT &&
+            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            op->src[0]->buffer &&
+            (ggml_n_dims(op->src[0]) == 2) &&
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
@@ -261,6 +442,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
+            else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
+                     op->src[0]->op == GGML_OP_VIEW &&
+                     (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
+                     op->src[1]->ne[1] > 1) {
+                if ((op->src[0]->nb[0] != 2) ||
+                    (op->src[1]->nb[0] != 4) ||
+                    (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+                    return nullptr;
+                }
+
+                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
+            }
         }
         return nullptr;
     }
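
Reading the guards in the new else-if branch above: in ggml, nb[0] is the byte stride between consecutive elements, so nb[0] == 2 requires a densely packed F16 src0 and nb[0] == 4 a densely packed F32 src1, while nb[1] * ne[1] == nb[2] rejects views whose rows are not laid out back to back within each plane. The toy descriptor below illustrates only that last contiguity check; the struct and values are illustrative stand-ins, not ggml_tensor itself.

    #include <cstddef>
    #include <cstdio>

    // Stand-in for the ne/nb bookkeeping of a ggml_tensor: ne[i] is the extent
    // in dimension i, nb[i] the stride in bytes.
    struct toy_tensor {
        long   ne[4];
        size_t nb[4];
    };

    // Mirrors the guard in get_tensor_traits: rows of every plane must be
    // packed back to back, i.e. plane stride == row stride * row count.
    static bool planes_are_contiguous(const toy_tensor & t) {
        return t.nb[1] * (size_t) t.ne[1] == t.nb[2];
    }

    int main() {
        // A dense 8x4x3 F16 tensor: element stride 2, row stride 2*8,
        // plane stride 2*8*4.
        toy_tensor dense = {{8, 4, 3, 1}, {2, 2 * 8, 2 * 8 * 4, 2 * 8 * 4 * 3}};

        // A view keeping only 3 of the 4 rows per plane: 16 * 3 != 64, so the
        // planes are no longer contiguous and the KleidiAI path must decline.
        toy_tensor view = {{8, 3, 3, 1}, {2, 2 * 8, 2 * 8 * 4, 2 * 8 * 4 * 3}};

        std::printf("dense: %d, view: %d\n",
                    planes_are_contiguous(dense), planes_are_contiguous(view));
        return 0;
    }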