@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -3,7 +3,9 @@
3
3
  //
4
4
  #include <arm_neon.h>
5
5
  #include <assert.h>
6
+ #include <atomic>
6
7
  #include <cfloat>
8
+ #include <stdexcept>
7
9
  #include <stdint.h>
8
10
  #include <string.h>
9
11
  #if defined(__linux__)
@@ -34,8 +36,9 @@
34
36
  #include "ggml-common.h"
35
37
 
36
38
  struct ggml_kleidiai_context {
39
+ cpu_feature features;
37
40
  ggml_kleidiai_kernels * kernels;
38
- } static ctx = { NULL };
41
+ } static ctx = { CPU_FEATURE_NONE, NULL };
39
42
 
40
43
  static void init_kleidiai_context(void) {
41
44
 
@@ -47,18 +50,18 @@ static void init_kleidiai_context(void) {
47
50
  const char *env_var = getenv("GGML_KLEIDIAI_SME");
48
51
  int sme_enabled = 0;
49
52
 
50
- cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
51
- (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
52
- (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
53
+ ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
54
+ (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
55
+ (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
53
56
 
54
57
  if (env_var) {
55
58
  sme_enabled = atoi(env_var);
56
59
  }
57
60
 
58
61
  if (sme_enabled != 0) {
59
- features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
62
+ ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
60
63
  }
61
- ctx.kernels = ggml_kleidiai_select_kernels(features);
64
+ ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
62
65
  }
63
66
  ggml_critical_section_end();
64
67
  }
@@ -68,95 +71,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
68
71
  return tensor->ne[dim];
69
72
  }
70
73
 
74
+ template<typename Ret, typename Variant, typename... Args>
75
+ static Ret variant_call(const Variant & var, Args&&... args) {
76
+ return std::visit([&](auto&& func) -> Ret {
77
+ if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
78
+ return func(std::forward<Args>(args)...);
79
+ } else {
80
+ throw std::runtime_error("Invalid function type in variant_call");
81
+ }
82
+ }, var);
83
+ }
84
+
71
85
  namespace ggml::cpu::kleidiai {
86
+
87
+ static size_t round_down(size_t x, size_t y) {
88
+ return y == 0 ? x : x - (x % y);
89
+ }
90
+
91
+ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
92
+ size_t src_stride = rhs_stride / sizeof(uint16_t);
93
+ size_t dst_stride = n;
94
+
95
+ for (size_t k_idx = 0; k_idx < k; ++k_idx) {
96
+ for (size_t n_idx = 0; n_idx < n; ++n_idx) {
97
+ uint16_t v = *(src + k_idx + n_idx * src_stride);
98
+ *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
99
+ }
100
+ }
101
+ }
102
+
72
103
  class tensor_traits : public ggml::cpu::tensor_traits {
73
104
  bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
74
- GGML_ASSERT(ctx.kernels);
75
- kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
105
+ ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
106
+ GGML_ASSERT(kernels);
107
+ kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
76
108
 
77
109
  size_t k = op->src[0]->ne[0];
110
+ size_t n = op->src[0]->ne[1];
78
111
  size_t m = op->src[1]->ne[1];
79
112
 
80
113
  size_t mr = kernel->get_mr();
81
114
  size_t kr = kernel->get_kr();
82
115
  size_t sr = kernel->get_sr();
83
116
 
84
- size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
117
+ if (kernels->rhs_type == GGML_TYPE_Q4_0) {
118
+ size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
119
+ } else if (kernels->rhs_type == GGML_TYPE_F16) {
120
+ size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
121
+ variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
122
+ k * n * sizeof(float) + n * sizeof(float);
123
+ } else {
124
+ GGML_ASSERT(false);
125
+ }
85
126
 
86
127
  return true;
87
128
  }
88
129
 
130
+
89
131
  bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
90
132
  if (dst->op == GGML_OP_MUL_MAT) {
91
- const ggml_tensor * src0 = dst->src[0];
92
- const ggml_tensor * src1 = dst->src[1];
133
+ if (dst->src[0]->type == GGML_TYPE_Q4_0) {
134
+ return compute_forward_q4_0(params, dst);
135
+ } else if (dst->src[0]->type == GGML_TYPE_F16) {
136
+ return compute_forward_kv_cache(params, dst);
137
+ }
138
+ }
139
+ return false;
140
+ }
93
141
 
94
- GGML_TENSOR_BINARY_OP_LOCALS
142
+ bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
143
+ static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
95
144
 
96
- GGML_ASSERT(ctx.kernels);
97
- kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
98
- lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
145
+ const ggml_tensor * src0 = dst->src[0];
146
+ const ggml_tensor * src1 = dst->src[1];
99
147
 
100
- GGML_ASSERT(kernel);
148
+ GGML_TENSOR_BINARY_OP_LOCALS
101
149
 
102
- const int ith = params->ith;
103
- const int nth = params->nth;
150
+ ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
151
+ GGML_ASSERT(kernels);
104
152
 
105
- const size_t k = ne00;
106
- const size_t m = ne11;
107
- const size_t n = ne01;
153
+ kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
154
+ GGML_ASSERT(kernel);
108
155
 
109
- const size_t n_step = kernel->get_n_step();
110
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
111
- const size_t n_start = ith * num_n_per_thread;
156
+ const int nth = params->nth;
157
+ const int ith = params->ith;
112
158
 
113
- size_t n_to_process = num_n_per_thread;
114
- if ((n_start + n_to_process) > n) {
115
- n_to_process = n - n_start;
116
- }
159
+ const int64_t lhs_batch_size0 = ne12;
160
+ const int64_t rhs_batch_size0 = ne02;
161
+ const int64_t batch_size = rhs_batch_size0;
162
+
163
+ const int64_t r = lhs_batch_size0 / rhs_batch_size0;
164
+
165
+ const int64_t m = ne11 * r;
166
+ const int64_t n = ne01;
167
+ const int64_t k = ne00;
168
+
169
+ const size_t lhs_stride = src1->nb[1];
170
+ const size_t rhs_stride = src0->nb[1];
171
+ const size_t dst_stride = dst->nb[1];
172
+
173
+ const int64_t mr = static_cast<int64_t>(kernel->get_mr());
174
+ const int64_t nr = static_cast<int64_t>(kernel->get_nr());
175
+ const int64_t kr = static_cast<int64_t>(kernel->get_kr());
176
+ const int64_t sr = static_cast<int64_t>(kernel->get_sr());
177
+
178
+ const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
179
+ const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
180
+ const size_t kxn_size = k * n * sizeof(float);
181
+ const size_t bias_size = n * sizeof(float);
182
+
183
+ const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
184
+ GGML_ASSERT(wsize_required <= params->wsize);
185
+
186
+ uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
187
+ uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
188
+ uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
189
+ uint8_t * bias = rhs_kxn + kxn_size;
190
+
191
+ for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
192
+ const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
193
+ const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
194
+ uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
117
195
 
118
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
119
- uint8_t * lhs_packed = (uint8_t*)params->wdata;
120
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
196
+ // LHS packing
197
+ {
198
+ const int64_t m_roundup_mr = kai_roundup(m, mr);
199
+ const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
121
200
 
122
- size_t mr = kernel->get_mr();
123
- size_t kr = kernel->get_kr();
124
- size_t sr = kernel->get_sr();
201
+ if (ith < num_threads) {
202
+ const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
203
+ const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
125
204
 
126
- // Calculate number of columns to be processed per thread
127
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
128
- const size_t m_start = ith * num_m_per_thread;
129
- size_t m_to_process = num_m_per_thread;
130
- if ((m_start + m_to_process) > m) {
131
- m_to_process = m - m_start;
205
+ const int64_t m_start = ith * num_m_per_thread0;
206
+ const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
207
+
208
+ const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
209
+ const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
210
+
211
+ const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
212
+ void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
213
+
214
+ variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
215
+ }
132
216
  }
133
217
 
134
- if(m_start < m) {
135
- // Transform LHS
136
- const size_t src_stride = src1->nb[1];
137
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
138
- const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
139
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
218
+ // RHS packing
219
+ if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
220
+ // First thread to reach this point handles RHS packing
221
+ memset(bias, 0, n * sizeof(float));
222
+ transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
223
+ reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
140
224
 
141
- lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
225
+ variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
226
+ rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
142
227
  }
143
228
 
144
229
  ggml_barrier(params->threadpool);
145
230
 
146
- // Perform the operation
147
- const size_t dst_stride = dst->nb[1];
148
- const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
149
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
150
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
151
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
152
- const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
153
- float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
154
-
155
- kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
156
- dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
157
- return true;
231
+ first_to_arrive.clear(std::memory_order_release);
232
+
233
+ // Perform the matmul
234
+ {
235
+ const int64_t m_to_process = m;
236
+ const int64_t m_start = 0;
237
+
238
+ const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
239
+ const int64_t num_threads = KAI_MIN(n / n_step, nth);
240
+
241
+ if (ith < num_threads) {
242
+ const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
243
+ const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
244
+
245
+ const int64_t n_start = ith * num_n_per_thread0;
246
+ const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
247
+
248
+ const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
249
+ const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
250
+ const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
251
+
252
+ const void * lhs_ptr = lhs_packed + lhs_packed_offset;
253
+ const void * rhs_ptr = rhs_packed + rhs_packed_offset;
254
+ float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
255
+
256
+ variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
257
+ }
258
+ }
259
+
260
+ if (batch_idx != batch_size - 1) {
261
+ // This barrier is necessary when the batch size is larger than 1. While processing a batch,
262
+ // the work data buffer (params->wdata) is used as temporary storage which means that only
263
+ // a single batch can be processed at any given time. No barrier is needed for the last
264
+ // batch since GGML inserts a barrier between the execution of every operator.
265
+ ggml_barrier(params->threadpool);
266
+ }
158
267
  }
159
- return false;
268
+
269
+ return true;
270
+ }
271
+
272
+ bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
273
+ const ggml_tensor * src0 = dst->src[0];
274
+ const ggml_tensor * src1 = dst->src[1];
275
+
276
+ GGML_TENSOR_BINARY_OP_LOCALS
277
+
278
+ ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
279
+ GGML_ASSERT(kernels);
280
+
281
+ kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
282
+ lhs_packing_info * lhs_info = &kernels->lhs_info;
283
+
284
+ GGML_ASSERT(kernel);
285
+
286
+ const int ith = params->ith;
287
+ const int nth = params->nth;
288
+
289
+ const size_t k = ne00;
290
+ const size_t m = ne11;
291
+ const size_t n = ne01;
292
+
293
+ size_t mr = kernel->get_mr();
294
+ size_t kr = kernel->get_kr();
295
+ size_t sr = kernel->get_sr();
296
+
297
+ const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
298
+ uint8_t * lhs_packed = (uint8_t*)params->wdata;
299
+ const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
300
+
301
+ const size_t n_step = kernel->get_n_step();
302
+ const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
303
+ const size_t n_start = ith * num_n_per_thread;
304
+
305
+ size_t n_to_process = num_n_per_thread;
306
+ if ((n_start + n_to_process) > n) {
307
+ n_to_process = n - n_start;
308
+ }
309
+
310
+ // Calculate number of columns to be processed per thread
311
+ const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
312
+ const size_t m_start = ith * num_m_per_thread;
313
+ size_t m_to_process = num_m_per_thread;
314
+ if ((m_start + m_to_process) > m) {
315
+ m_to_process = m - m_start;
316
+ }
317
+
318
+ if (m_start < m) {
319
+ // Transform LHS
320
+ const size_t src_stride = src1->nb[1];
321
+ const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
322
+ const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
323
+ void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
324
+
325
+ variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
326
+ }
327
+
328
+ ggml_barrier(params->threadpool);
329
+
330
+ // Perform the operation
331
+ const size_t dst_stride = dst->nb[1];
332
+ const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
333
+ const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
334
+ const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
335
+ const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
336
+ const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
337
+ float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
338
+
339
+ variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
340
+ sizeof(float), -FLT_MAX, FLT_MAX);
341
+
342
+ return true;
160
343
  }
161
344
 
162
345
  public:
@@ -169,13 +352,13 @@ public:
169
352
  size_t sr = ctx.kernels->gemm.get_sr();
170
353
 
171
354
  #ifndef NDEBUG
172
- const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
355
+ const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
173
356
  GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
174
357
  #endif
175
358
  struct kai_rhs_pack_qs4cxs1s0_param params;
176
359
  params.lhs_zero_point = 1;
177
360
  params.rhs_zero_point = 8;
178
- ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
361
+ variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
179
362
 
180
363
  return 0;
181
364
 
@@ -189,7 +372,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
189
372
  }
190
373
  } // namespace ggml::cpu::kleidiai
191
374
 
192
- GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
375
+ static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
193
376
  tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
194
377
 
195
378
  GGML_UNUSED(buffer);
@@ -238,12 +421,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
238
421
  namespace ggml::cpu::kleidiai {
239
422
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
240
423
  bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
241
- if ( op->op == GGML_OP_MUL_MAT &&
242
- op->src[0]->type == GGML_TYPE_Q4_0 &&
243
- op->src[0]->buffer &&
244
- (ggml_n_dims(op->src[0]) == 2) &&
245
- op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
246
- ) {
424
+ if (op->op == GGML_OP_MUL_MAT &&
425
+ op->src[0]->type == GGML_TYPE_Q4_0 &&
426
+ op->src[0]->buffer &&
427
+ (ggml_n_dims(op->src[0]) == 2) &&
428
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
247
429
  if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
248
430
  return false;
249
431
  }
@@ -260,6 +442,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
260
442
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
261
443
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
262
444
  }
445
+ else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
446
+ op->src[0]->op == GGML_OP_VIEW &&
447
+ (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
448
+ op->src[1]->ne[1] > 1) {
449
+ if ((op->src[0]->nb[0] != 2) ||
450
+ (op->src[1]->nb[0] != 4) ||
451
+ (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
452
+ (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
453
+ return nullptr;
454
+ }
455
+
456
+ return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
457
+ }
263
458
  }
264
459
  return nullptr;
265
460
  }