@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -214,7 +214,7 @@ add_library(ggml
  target_link_libraries(ggml PUBLIC ggml-base)

  if (CMAKE_SYSTEM_NAME MATCHES "Linux")
- target_link_libraries(ggml PRIVATE dl stdc++fs)
+ target_link_libraries(ggml PRIVATE dl)
  endif()

  function(ggml_add_backend_library backend)
@@ -816,7 +816,10 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
  static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
  size_t node_size = 0;
  if (!node->data && !node->view_src) {
- GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API
+ // If we previously had data but don't now then reallocate
+ if (talloc->buffer_id < 0) {
+ return false;
+ }
  node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
  }
  return talloc->size_max >= node_size;
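Note on the ggml-alloc.c hunk above: instead of asserting on a negative buffer_id, ggml_gallocr_node_needs_realloc now reports that the graph needs to be reallocated. A minimal caller sketch in C of the flow this feeds into, assuming the usual gallocr usage; build_graph and ctx are illustrative placeholders, not part of this diff:

    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    // ggml_gallocr_alloc_graph() re-reserves its buffers when the graph no longer
    // matches the previous allocation, so a "needs realloc" result now leads to a
    // reallocation rather than the old hard assert.
    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    struct ggml_cgraph * graph = build_graph(ctx);   // illustrative helper
    if (!ggml_gallocr_alloc_graph(galloc, graph)) {
        // allocation failed even after re-reserving buffers
    }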
@@ -56,7 +56,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
  return SIZE_MAX;
  }

- size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
  // get_alloc_size is optional, defaults to ggml_nbytes
  if (buft->iface.get_alloc_size) {
  size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -152,7 +152,7 @@ size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
  return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
  }

- size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) {
  return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
  }

@@ -674,6 +674,8 @@ struct ggml_backend_sched {
  char * context_buffer;
  size_t context_buffer_size;

+ bool op_offload;
+
  int debug;
  };

@@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
  if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
  int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
  // check if a backend with higher prio wants to offload the op
- if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
+ if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
  for (int b = 0; b < src_backend_id; b++) {
  if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
  SET_CAUSE(tensor, "1.off");
@@ -1109,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg

  const int node_backend_id = tensor_backend_id(node);

- assert(node_backend_id != -1); // all nodes should be assigned by now
+ assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback

  // check if we should start a new split based on the sources of the current node
  bool need_new_split = false;
@@ -1452,7 +1454,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
  ggml_backend_buffer_type_t * bufts,
  int n_backends,
  size_t graph_size,
- bool parallel) {
+ bool parallel,
+ bool op_offload) {
  GGML_ASSERT(n_backends > 0);
  GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
  GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
@@ -1497,6 +1500,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
  }

  sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
+ sched->op_offload = op_offload;

  ggml_backend_sched_reset(sched);
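Note on the ggml-backend.cpp hunks above: ggml_backend_sched_new() now takes a trailing op_offload flag, which gates the weights-based operation offload checked in ggml_backend_sched_backend_id_from_cur(). A rough caller sketch, assuming backends and n_backends are already set up (those names are illustrative):

    // The final argument is the new op_offload flag; the other arguments keep
    // their existing meaning (CPU backend last, default buffer types, etc.).
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends,                   // ggml_backend_t array, CPU backend last
        NULL,                       // use the default buffer type of each backend
        n_backends,
        GGML_DEFAULT_GRAPH_SIZE,
        /*parallel  =*/ false,
        /*op_offload=*/ true);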
 
@@ -385,9 +385,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

  # Fetch KleidiAI sources:
  include(FetchContent)
- set(KLEIDIAI_COMMIT_TAG "v1.5.0")
+ set(KLEIDIAI_COMMIT_TAG "v1.6.0")
  set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
- set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e")
+ set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f")

  if (POLICY CMP0135)
  cmake_policy(SET CMP0135 NEW)
@@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
  ${KLEIDIAI_SRC}/kai/ukernels/
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
  ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)

  set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
  string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
  string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)

- set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
+ set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})

- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+ list(APPEND GGML_KLEIDIAI_SOURCES
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)

  if (NOT DOTPROD_ENABLED MATCHES -1)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+ list(APPEND GGML_KLEIDIAI_SOURCES
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
  endif()

  if (NOT I8MM_ENABLED MATCHES -1)
@@ -456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
  endif()

  if (NOT SME_ENABLED MATCHES -1)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
- list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
- set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
+ list(APPEND GGML_KLEIDIAI_SOURCES
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
+ set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
  endif()

  set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
@@ -72,8 +72,6 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro

  #if defined(__GNUC__)
  #pragma GCC diagnostic ignored "-Woverlength-strings"
- #elif defined(_MSC_VER)
- #pragma warning(disable: 4244 4267) // possible loss of data
  #endif

  #define UNUSED GGML_UNUSED
@@ -20,12 +20,6 @@
  #define GROUP_MAX_EPS_IQ1_M 1e-7f
  #define GROUP_MAX_EPS_IQ1_S 1e-12f

- #if defined(_MSC_VER)
- // disable "possible loss of data" to avoid warnings for hundreds of casts
- // we should just be careful :)
- #pragma warning(disable: 4244 4267)
- #endif
-
  #define UNUSED GGML_UNUSED

  // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
@@ -6596,7 +6590,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
  }

  *s = hsum_float_8(acc);
+ #elif defined(__VXE__) || defined(__VXE2__)
+ uint32_t aux[3];
+ uint32_t utmp[4];
+
+ const int32x4_t v_z = vec_splat_s32(0);
+ const uint8x16_t v_3m = vec_splat_u8(0x03);
+
+ const uint8x16_t v_0c = vec_splat_u8(1);
+ const uint8x16_t v_1c = vec_sl(v_0c, 1);
+ const uint8x16_t v_2c = vec_sl(v_0c, 2);
+ const uint8x16_t v_3c = vec_sl(v_0c, 3);
+
+ uint8x16_t q3h[4];
+ uint8x16_t q3b[2];
+ int8x16_t q3bytes[4];
+ int8x16_t q8bytes[4];
+ uint8x16_t qhbits[2];
+
+ float sum = 0;
+
+ for (int i = 0; i < nb; ++i) {
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+ const uint8_t * restrict x0l = x[i].qs;
+ const uint8_t * restrict x0h = x[i].hmask;
+ const int8_t * restrict y0 = y[i].qs;
+
+ qhbits[0] = vec_xl(0 , x0h);
+ qhbits[1] = vec_xl(16, x0h);
+
+ int32_t isum = 0;
+
+ memcpy(aux, x[i].scales, 12);
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+ int8_t * scale = (int8_t *)utmp;
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ int32x4_t isum0, isum1, isum2, isum3;
+
+ q3b[0] = vec_xl(0 , x0l);
+ q3b[1] = vec_xl(16, x0l);
+ x0l += 32;
+
+ q8bytes[0] = vec_xl(0 , y0);
+ q8bytes[1] = vec_xl(16 , y0);
+ q8bytes[2] = vec_xl(32 , y0);
+ q8bytes[3] = vec_xl(48 , y0);
+ q8bytes[4] = vec_xl(64 , y0);
+ q8bytes[5] = vec_xl(80 , y0);
+ q8bytes[6] = vec_xl(96 , y0);
+ q8bytes[7] = vec_xl(112, y0);
+ y0 += 128;
+
+ q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
+ q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
+ q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
+ q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
+
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
+
+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
+
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
+
+ scale += 4;
+
+ q3h[0] = vec_andc(v_2c, qhbits[0]);
+ q3h[1] = vec_andc(v_2c, qhbits[1]);
+ q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
+ q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
+
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
+
+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
+
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
+
+ scale += 4;
+
+ if (j == 0) {
+ qhbits[0] = vec_sr(qhbits[0], 4);
+ qhbits[1] = vec_sr(qhbits[1], 4);
+ }
+ }
+
+ sum += d * isum;
+ }

+ *s = sum;
  #else
  // scalar version
  // This function is written like this so the compiler can manage to vectorize most of it
@@ -8414,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
  assert(n % QK_K == 0);
+ #ifdef __ARM_FEATURE_MATMUL_INT8
+ assert((nrc == 2) || (nrc == 1));
+ #else
  assert(nrc == 1);
+ #endif
  UNUSED(nrc);
  UNUSED(bx);
  UNUSED(by);
@@ -8425,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  const int nb = n / QK_K;

+ #if defined(__ARM_FEATURE_MATMUL_INT8)
+ if (nrc == 2) {
+ const block_q6_K * GGML_RESTRICT x0 = x;
+ const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
+ const block_q8_K * GGML_RESTRICT y0 = y;
+ const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+ float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+ for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+ const uint8_t * GGML_RESTRICT ql0 = x0->ql;
+ const uint8_t * GGML_RESTRICT ql1 = x1->ql;
+ const uint8_t * GGML_RESTRICT qh0 = x0->qh;
+ const uint8_t * GGML_RESTRICT qh1 = x1->qh;
+ const int8_t * GGML_RESTRICT qy0 = y0->qs;
+ const int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+ const uint8x16_t mone = vdupq_n_u8(0x30);
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+ int32x4_t visum = vdupq_n_s32(0);
+
+ // process 8 blocks per iteration, totally 16 blocks
+ for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
+ int8x16_t vx0[8], vx1[8];
+
+ // de-quantize vx0[8]
+ {
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
+
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+ vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+ vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+ vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+ vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+ vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+ vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+ vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+ vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+ }
+
+ // de-quantize vx1[8]
+ {
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
+
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+ vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+ vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+ vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+ vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+ vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+ vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+ vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+ vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+ }
+
+ // process 16 elements (one block with same scale) per iteration
+ // - vx = concat(ql, qh) - 32
+ // - r1,r2,r3,r4 = smmla(vx, vy)
+ for (int k = 0; k < 8; ++k) {
+ const int blk = j * 8 + k;
+
+ const int8x16_t vy0 = vld1q_s8(qy0);
+ const int8x16_t vy1 = vld1q_s8(qy1);
+ qy0 += 16;
+ qy1 += 16;
+
+ const int32x4_t block_scale = {
+ x0->scales[blk],
+ x0->scales[blk],
+ x1->scales[blk],
+ x1->scales[blk],
+ };
+
+ // calculate four results at once with outer product
+ const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+ const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+ const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+ const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+ int32x4_t vr = vdupq_n_s32(0);
+ vr = vmmlaq_s32(vr, vx_l, vy_l);
+ vr = vmmlaq_s32(vr, vx_h, vy_h);
+
+ // apply block scale, will NOT overflow
+ // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
+ visum = vmlaq_s32(visum, vr, block_scale);
+ }
+ }
+
+ // adjust bias, apply superblock scale
+ {
+ int32_t bias[4];
+ #ifdef __ARM_FEATURE_SVE
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+ const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
+ const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
+ const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
+ const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
+ const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
+ const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
+ const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
+ const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
+ const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
+ const svint64_t zero = svdup_n_s64(0);
+ bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
+ svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
+ bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
+ svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
+ bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
+ svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
+ bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
+ svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
+ #else
+ // NEON doesn't support int16 dot product, fallback to separated mul and add
+ const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
+ const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
+
+ int8x16_t scales_s8 = vld1q_s8(x0->scales);
+ const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+ scales_s8 = vld1q_s8(x1->scales);
+ const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+
+ int32x4_t prod;
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
+ bias[0] = vaddvq_s32(prod);
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
+ bias[1] = vaddvq_s32(prod);
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
+ bias[2] = vaddvq_s32(prod);
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
+ bias[3] = vaddvq_s32(prod);
+
+ #endif
+ const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
+
+ const float32x4_t superblock_scale = {
+ GGML_FP16_TO_FP32(x0->d) * y0->d,
+ GGML_FP16_TO_FP32(x0->d) * y1->d,
+ GGML_FP16_TO_FP32(x1->d) * y0->d,
+ GGML_FP16_TO_FP32(x1->d) * y1->d,
+ };
+
+ visum = vsubq_s32(visum, vibias);
+ vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+ }
+ }
+
+ // vfsum = ABCD -> ACBD
+ // AC -> s, BD -> (s+bs)
+ vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+ vst1_f32(s, vget_low_f32 (vfsum));
+ vst1_f32(s + bs, vget_high_f32(vfsum));
+
+ return;
+ }
+ #endif
+
  #ifdef __ARM_FEATURE_SVE
  const int vector_length = ggml_cpu_get_sve_cnt()*8;
  float sum = 0;
@@ -50,19 +50,6 @@
  #include "llamafile/sgemm.h"
  #endif

- #if defined(_MSC_VER)
- // disable "possible loss of data" to avoid hundreds of casts
- // we should just be careful :)
- #pragma warning(disable: 4244 4267)
-
- // disable POSIX deprecation warnings
- // these functions are never going away, anyway
- #pragma warning(disable: 4996)
-
- // unreachable code because of multiple instances of code after GGML_ABORT
- #pragma warning(disable: 4702)
- #endif
-
  // Note: once we move threading into a separate C++ file
  // will use std::hardware_destructive_interference_size instead of hardcoding it here
  // and we'll use C++ attribute syntax.
@@ -295,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
  .from_float = quantize_row_q6_K,
  .vec_dot = ggml_vec_dot_q6_K_q8_K,
  .vec_dot_type = GGML_TYPE_Q8_K,
+ #if defined (__ARM_FEATURE_MATMUL_INT8)
+ .nrows = 2,
+ #else
  .nrows = 1,
+ #endif
  },
  [GGML_TYPE_IQ2_XXS] = {
  .from_float = NULL,
@@ -11,24 +11,26 @@
  #include <vector>

  #ifdef GGML_USE_CPU_HBM
- #include "ggml-cpu-hbm.h"
+ # include "ggml-cpu-hbm.h"
  #endif

  #ifdef GGML_USE_CPU_KLEIDIAI
- #include "kleidiai/kleidiai.h"
- #endif
-
- #if defined(__APPLE__)
- #include <sys/types.h>
- #include <sys/sysctl.h>
+ # include "kleidiai/kleidiai.h"
  #endif

  #if defined(_WIN32)
- #define WIN32_LEAN_AND_MEAN
- #ifndef NOMINMAX
- #define NOMINMAX
+ # define WIN32_LEAN_AND_MEAN
+ # ifndef NOMINMAX
+ # define NOMINMAX
+ # endif
+ # include <windows.h>
+ #else
+ # include <unistd.h>
  #endif
- #include <windows.h>
+
+ #if defined(__APPLE__)
+ # include <sys/sysctl.h>
+ # include <sys/types.h>
  #endif

  // ggml-backend interface
@@ -70,8 +72,10 @@ static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_ty
  }

  static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
- for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
- if (extra && extra == buft) return true;
+ for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
+ if (extra && extra == buft) {
+ return true;
+ }
  }
  return false;
  }
@@ -330,9 +334,18 @@ static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t d
  }

  static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
- // TODO
- *free = 0;
- *total = 0;
+ #ifdef _WIN32
+ MEMORYSTATUSEX status;
+ status.dwLength = sizeof(status);
+ GlobalMemoryStatusEx(&status);
+ *total = status.ullTotalPhys;
+ *free = status.ullAvailPhys;
+ #else
+ long pages = sysconf(_SC_PHYS_PAGES);
+ long page_size = sysconf(_SC_PAGE_SIZE);
+ *total = pages * page_size;
+ *free = *total;
+ #endif

  GGML_UNUSED(dev);
  }
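Note on the last hunk: the CPU device now reports actual physical memory (GlobalMemoryStatusEx on Windows, sysconf elsewhere) instead of zeros. A small sketch of how callers reach this path through the public device API; the lookup-by-type pattern is assumed, not taken from this diff:

    #include "ggml-backend.h"

    // Queries the CPU device; this ends up in ggml_backend_cpu_device_get_memory()
    // above, which now fills in real totals instead of 0.
    size_t free_mem = 0, total_mem = 0;
    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    ggml_backend_dev_memory(cpu_dev, &free_mem, &total_mem);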