@fugood/llama.node 0.3.6 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186)
  1. package/README.md +17 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +3 -1
  19. package/lib/index.js +16 -1
  20. package/lib/index.ts +16 -0
  21. package/package.json +1 -1
  22. package/src/EmbeddingWorker.cpp +4 -3
  23. package/src/LlamaCompletionWorker.cpp +4 -2
  24. package/src/LlamaContext.cpp +61 -6
  25. package/src/LlamaContext.h +1 -0
  26. package/src/common.hpp +6 -11
  27. package/src/llama.cpp/.github/workflows/build.yml +19 -17
  28. package/src/llama.cpp/.github/workflows/docker.yml +77 -30
  29. package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +22 -3
  31. package/src/llama.cpp/CMakeLists.txt +49 -24
  32. package/src/llama.cpp/common/arg.cpp +82 -26
  33. package/src/llama.cpp/common/arg.h +3 -0
  34. package/src/llama.cpp/common/common.cpp +192 -72
  35. package/src/llama.cpp/common/common.h +51 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +12 -12
  37. package/src/llama.cpp/common/ngram-cache.h +2 -2
  38. package/src/llama.cpp/common/sampling.cpp +11 -6
  39. package/src/llama.cpp/common/speculative.cpp +18 -15
  40. package/src/llama.cpp/docs/build.md +2 -0
  41. package/src/llama.cpp/examples/batched/batched.cpp +9 -7
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
  43. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
  44. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
  45. package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
  46. package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
  47. package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
  48. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
  50. package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
  51. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
  52. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
  53. package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
  54. package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
  55. package/src/llama.cpp/examples/infill/infill.cpp +23 -24
  56. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
  57. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
  58. package/src/llama.cpp/examples/llava/clip.cpp +4 -2
  59. package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
  60. package/src/llama.cpp/examples/llava/llava.cpp +2 -2
  61. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
  62. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
  63. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
  64. package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
  65. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
  66. package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
  67. package/src/llama.cpp/examples/main/main.cpp +51 -29
  68. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
  69. package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
  70. package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
  71. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
  72. package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
  73. package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
  74. package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
  76. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
  77. package/src/llama.cpp/examples/run/run.cpp +175 -61
  78. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
  79. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
  80. package/src/llama.cpp/examples/server/httplib.h +1295 -409
  81. package/src/llama.cpp/examples/server/server.cpp +387 -181
  82. package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
  83. package/src/llama.cpp/examples/server/utils.hpp +170 -58
  84. package/src/llama.cpp/examples/simple/simple.cpp +9 -8
  85. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
  86. package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
  87. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
  88. package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
  89. package/src/llama.cpp/examples/tts/tts.cpp +64 -23
  90. package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
  91. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  92. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
  93. package/src/llama.cpp/ggml/include/ggml.h +36 -145
  94. package/src/llama.cpp/ggml/include/gguf.h +202 -0
  95. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  96. package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
  97. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
  98. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
  99. package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
  100. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
  101. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
  102. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
  103. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
  105. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
  106. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
  107. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  109. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
  111. package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
  112. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
  113. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
  115. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
  117. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
  120. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
  121. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
  124. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
  125. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
  126. package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
  128. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
  129. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
  130. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
  131. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
  132. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
  133. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
  134. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
  135. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
  138. package/src/llama.cpp/ggml/src/ggml.c +117 -1327
  139. package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
  140. package/src/llama.cpp/include/llama-cpp.h +6 -1
  141. package/src/llama.cpp/include/llama.h +138 -75
  142. package/src/llama.cpp/src/CMakeLists.txt +13 -1
  143. package/src/llama.cpp/src/llama-adapter.cpp +347 -0
  144. package/src/llama.cpp/src/llama-adapter.h +74 -0
  145. package/src/llama.cpp/src/llama-arch.cpp +1487 -0
  146. package/src/llama.cpp/src/llama-arch.h +400 -0
  147. package/src/llama.cpp/src/llama-batch.cpp +368 -0
  148. package/src/llama.cpp/src/llama-batch.h +88 -0
  149. package/src/llama.cpp/src/llama-chat.cpp +578 -0
  150. package/src/llama.cpp/src/llama-chat.h +52 -0
  151. package/src/llama.cpp/src/llama-context.cpp +1775 -0
  152. package/src/llama.cpp/src/llama-context.h +128 -0
  153. package/src/llama.cpp/src/llama-cparams.cpp +1 -0
  154. package/src/llama.cpp/src/llama-cparams.h +37 -0
  155. package/src/llama.cpp/src/llama-grammar.cpp +5 -4
  156. package/src/llama.cpp/src/llama-grammar.h +3 -1
  157. package/src/llama.cpp/src/llama-hparams.cpp +71 -0
  158. package/src/llama.cpp/src/llama-hparams.h +139 -0
  159. package/src/llama.cpp/src/llama-impl.cpp +167 -0
  160. package/src/llama.cpp/src/llama-impl.h +16 -136
  161. package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
  162. package/src/llama.cpp/src/llama-kv-cache.h +218 -0
  163. package/src/llama.cpp/src/llama-mmap.cpp +589 -0
  164. package/src/llama.cpp/src/llama-mmap.h +67 -0
  165. package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
  166. package/src/llama.cpp/src/llama-model-loader.h +167 -0
  167. package/src/llama.cpp/src/llama-model.cpp +3953 -0
  168. package/src/llama.cpp/src/llama-model.h +370 -0
  169. package/src/llama.cpp/src/llama-quant.cpp +934 -0
  170. package/src/llama.cpp/src/llama-quant.h +1 -0
  171. package/src/llama.cpp/src/llama-sampling.cpp +147 -32
  172. package/src/llama.cpp/src/llama-sampling.h +3 -19
  173. package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
  174. package/src/llama.cpp/src/llama-vocab.h +97 -142
  175. package/src/llama.cpp/src/llama.cpp +7160 -20314
  176. package/src/llama.cpp/src/unicode.cpp +8 -3
  177. package/src/llama.cpp/tests/CMakeLists.txt +2 -0
  178. package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
  179. package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
  180. package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
  181. package/src/llama.cpp/tests/test-gguf.cpp +222 -187
  182. package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
  183. package/src/llama.cpp/tests/test-sampling.cpp +0 -1
  184. package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
  185. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
  186. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
@@ -82,8 +82,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
82
82
  if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
83
83
  message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
84
84
  else()
85
- check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
86
- if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
85
+ check_cxx_compiler_flag(-mfp16-format=ieee GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E)
86
+ if (NOT "${GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
87
87
  list(APPEND ARCH_FLAGS -mfp16-format=ieee)
88
88
  endif()
89
89
 
@@ -106,28 +106,28 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
106
106
  message(STATUS "ARM -mcpu not found, -mcpu=native will be used")
107
107
  endif()
108
108
 
109
- set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
110
109
  include(CheckCXXSourceRuns)
111
110
 
112
- set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+dotprod")
113
- check_cxx_source_runs(
114
- "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }"
115
- GGML_COMPILER_SUPPORT_DOTPROD)
116
- if (GGML_COMPILER_SUPPORT_DOTPROD)
117
- set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+dotprod")
118
- endif()
111
+ function(check_arm_feature tag code)
112
+ set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
113
+ set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
114
+ check_cxx_source_runs(
115
+ "${code}"
116
+ GGML_MACHINE_SUPPORTS_${tag}
117
+ )
118
+ if (GGML_MACHINE_SUPPORTS_${tag})
119
+ set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
120
+ else()
121
+ set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
122
+ endif()
123
+ set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
124
+ endfunction()
119
125
 
120
- set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+i8mm")
121
- check_cxx_source_runs(
122
- "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }"
123
- GGML_COMPILER_SUPPORT_I8MM)
124
- if (GGML_COMPILER_SUPPORT_I8MM)
125
- set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+i8mm")
126
- endif()
126
+ check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
127
+ check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
128
+ check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
127
129
 
128
- set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
129
130
  list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
130
-
131
131
  else()
132
132
  if (GGML_CPU_ARM_ARCH)
133
133
  list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
@@ -135,14 +135,20 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
135
135
  endif()
136
136
 
137
137
  # show enabled features
138
+ if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
139
+ set(FEAT_INPUT_FILE "NUL")
140
+ else()
141
+ set(FEAT_INPUT_FILE "/dev/null")
142
+ endif()
143
+
138
144
  execute_process(
139
145
  COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
140
- INPUT_FILE "/dev/null"
146
+ INPUT_FILE ${FEAT_INPUT_FILE}
141
147
  OUTPUT_VARIABLE ARM_FEATURE
142
148
  RESULT_VARIABLE ARM_FEATURE_RESULT
143
149
  )
144
150
  if (ARM_FEATURE_RESULT)
145
- message(FATAL_ERROR "Failed to get ARM features")
151
+ message(WARNING "Failed to get ARM features")
146
152
  else()
147
153
  foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
148
154
  string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
@@ -209,8 +215,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
209
215
  list(APPEND ARCH_DEFINITIONS GGML_SSE42)
210
216
  endif()
211
217
  if (GGML_AVX_VNNI)
212
- # MSVC generates AVX512 with AVX-VNNI intrinsics even with /arch:AVX2
213
- #list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
218
+ list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
214
219
  endif()
215
220
  else ()
216
221
  if (GGML_NATIVE)
@@ -317,6 +322,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
317
322
  target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS})
318
323
 
319
324
  if (GGML_BACKEND_DL)
325
+ if (GGML_NATIVE)
326
+ # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE
327
+ message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS")
328
+ endif()
329
+
320
330
  # The feature detection code is compiled as a separate target so that
321
331
  # it can be built without the architecture flags
322
332
  # Since multiple variants of the CPU backend may be included in the same
@@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
194
194
  }
195
195
 
196
196
  static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
197
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
197
+ #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
198
198
  const __m256i zero = _mm256_setzero_si256();
199
199
  return _mm256_dpbusd_epi32(zero, ax, sy);
200
+ #elif defined(__AVXVNNI__)
201
+ const __m256i zero = _mm256_setzero_si256();
202
+ return _mm256_dpbusd_avx_epi32(zero, ax, sy);
200
203
  #else
201
204
  // Perform multiplication and create 16-bit values
202
205
  const __m256i dot = _mm256_maddubs_epi16(ax, sy);
@@ -564,21 +567,21 @@ static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
564
567
 
565
568
  #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
566
569
  if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
567
- const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
570
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
568
571
 
569
572
  for (int c = 0; c < nc; c += ncols_interleaved) {
570
- const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
573
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
571
574
  float32x4_t acc = vdupq_n_f32(0);
572
575
  for (int b = 0; b < nb; b++) {
573
- int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
574
- int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
575
- int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
576
- int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
577
- float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
576
+ int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
577
+ int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
578
+ int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
579
+ int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
580
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
578
581
 
579
582
  int8x16_t a0 = vld1q_s8(a_ptr->qs);
580
583
  int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
581
- float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
584
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
582
585
 
583
586
  int32x4_t ret = vdupq_n_s32(0);
584
587
 
@@ -647,72 +650,52 @@ static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
647
650
  UNUSED(ncols_interleaved);
648
651
  UNUSED(blocklen);
649
652
 
650
- #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
651
- if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
652
- const void * b_ptr = vx;
653
- const void * a_ptr = vy;
654
- float * res_ptr = s;
653
+ #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
654
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
655
+ const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
655
656
 
656
- __asm__ __volatile__(
657
- "movi v2.16b, #0x4\n"
658
- "movi v1.16b, #0xf0\n"
659
- "add %x[b_ptr], %x[b_ptr], #0x8\n"
660
- "1:" // Column loop
661
- "add x23, %x[a_ptr], #0x2\n"
662
- "movi v0.16b, #0x0\n"
663
- "mov x22, %x[nb]\n"
664
- "2:" // Block loop
665
- "ldr q31, [%x[b_ptr], #0x0]\n"
666
- "ldr q30, [%x[b_ptr], #0x10]\n"
667
- "mov x21, x23\n"
668
- "movi v29.4s, #0x0\n"
669
- "ldr q28, [%x[b_ptr], #0x20]\n"
670
- "ldr q27, [%x[b_ptr], #0x30]\n"
671
- "movi v26.4s, #0x0\n"
672
- "sub x20, x23, #0x2\n"
673
- "ld1r { v25.8h }, [x20]\n"
674
- "ldr q24, [%x[b_ptr], #-0x8]\n"
675
- "sub x22, x22, #0x1\n"
676
- "add x23, x23, #0x22\n"
677
- "ld1r { v23.2d }, [x21], #0x8\n"
678
- "sshl v22.16b, v31.16b, v2.16b\n"
679
- "sshl v16.16b, v30.16b, v2.16b\n"
680
- "add %x[b_ptr], %x[b_ptr], #0x48\n"
681
- "ld1r { v21.2d }, [x21], #0x8\n"
682
- "sshl v20.16b, v28.16b, v2.16b\n"
683
- "sshl v19.16b, v27.16b, v2.16b\n"
684
- "ld1r { v18.2d }, [x21], #0x8\n"
685
- "ld1r { v17.2d }, [x21], #0x8\n"
686
- "and v31.16b, v31.16b, v1.16b\n"
687
- "and v30.16b, v30.16b, v1.16b\n"
688
- ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n"
689
- ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n"
690
- "and v28.16b, v28.16b, v1.16b\n"
691
- "and v27.16b, v27.16b, v1.16b\n"
692
- "fcvtl v25.4s, v25.4h\n"
693
- "fcvtl v16.4s, v24.4h\n"
694
- ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n"
695
- ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
696
- "fmul v16.4s, v16.4s, v25.4s\n"
697
- ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
698
- ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
699
- ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
700
- ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
701
- "addp v29.4s, v29.4s, v26.4s\n"
702
- "scvtf v29.4s, v29.4s, #0x4\n"
703
- "fmla v0.4s, v29.4s, v16.4s\n"
704
- "cbnz x22, 2b\n"
705
- "sub %x[nc], %x[nc], #0x4\n"
706
- "str q0, [%x[res_ptr], #0x0]\n"
707
- "add %x[res_ptr], %x[res_ptr], #0x10\n"
708
- "cbnz %x[nc], 1b\n"
709
- : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
710
- : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
711
- : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
712
- );
657
+ for (int c = 0; c < nc; c += ncols_interleaved) {
658
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
659
+ float32x4_t acc = vdupq_n_f32(0);
660
+ for (int b = 0; b < nb; b++) {
661
+ int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
662
+ int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
663
+ int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
664
+ int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
665
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
666
+
667
+ int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
668
+ int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
669
+ int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
670
+ int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
671
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
672
+
673
+ int32x4_t ret0 = vdupq_n_s32(0);
674
+ int32x4_t ret1 = vdupq_n_s32(0);
675
+
676
+ ret0 = vdotq_s32(ret0, b0 << 4, a0);
677
+ ret1 = vdotq_s32(ret1, b1 << 4, a0);
678
+ ret0 = vdotq_s32(ret0, b2 << 4, a1);
679
+ ret1 = vdotq_s32(ret1, b3 << 4, a1);
680
+
681
+ ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
682
+ ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
683
+ ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
684
+ ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
685
+
686
+ int32x4_t ret = vpaddq_s32(ret0, ret1);
687
+
688
+ acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
689
+ vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
690
+ a_ptr++;
691
+ b_ptr++;
692
+ }
693
+ vst1q_f32(s, acc);
694
+ s += ncols_interleaved;
695
+ }
713
696
  return;
714
697
  }
715
- #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
698
+ #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
716
699
  float sumf[4];
717
700
  int sumi;
718
701
 
@@ -4186,6 +4169,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(g
4186
4169
  buffer->buft = buft;
4187
4170
  buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
4188
4171
  buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
4172
+ buffer->iface.get_tensor = nullptr;
4173
+ buffer->iface.cpy_tensor = nullptr;
4189
4174
  return buffer;
4190
4175
  }
4191
4176
 
@@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
103
103
  }
104
104
 
105
105
  static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
106
- #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
106
+ #if defined(__AVX512VNNI__) && defined(__AVX512VL__)
107
107
  const __m256i zero = _mm256_setzero_si256();
108
108
  const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
109
109
  return _mm256_cvtepi32_ps(summed_pairs);
110
+ #elif defined(__AVXVNNI__)
111
+ const __m256i zero = _mm256_setzero_si256();
112
+ const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
113
+ return _mm256_cvtepi32_ps(summed_pairs);
110
114
  #else
111
115
  // Perform multiplication and create 16-bit values
112
116
  const __m256i dot = _mm256_maddubs_epi16(ax, sy);
@@ -5569,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5569
5573
 
5570
5574
  uint32_t utmp[4];
5571
5575
 
5572
- #ifdef __ARM_NEON
5576
+ #ifdef __ARM_FEATURE_SVE
5577
+ float sumf = 0;
5578
+ for (int i = 0; i < nb; ++i) {
5579
+
5580
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
5581
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
5582
+
5583
+ const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
5584
+
5585
+ memcpy(utmp, x[i].scales, K_SCALE_SIZE);
5586
+
5587
+ uint32x2_t mins8 = { 0 };
5588
+ mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
5589
+ mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
5590
+
5591
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
5592
+ utmp[0] &= kmask1;
5593
+
5594
+ const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
5595
+ const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
5596
+ vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
5597
+ sumf -= dmin * vaddvq_s32(prod);
5598
+
5599
+ const uint8_t * scales = (const uint8_t *)utmp;
5600
+
5601
+ const uint8_t * restrict q4 = x[i].qs;
5602
+ const int8_t * restrict q8 = y[i].qs;
5603
+
5604
+ const int vector_length = ggml_cpu_get_sve_cnt()*8;
5605
+ const svuint8_t m4b = svdup_n_u8(0xf);
5606
+ const svint32_t mzero = svdup_n_s32(0);
5607
+ svint32_t sumi1 = svdup_n_s32(0);
5608
+ svint32_t sumi1_1 = svdup_n_s32(0);
5609
+ svint32_t sumi1_2 = svdup_n_s32(0);
5610
+ svint32_t sumi2 = svdup_n_s32(0);
5611
+ svint32_t sumi2_1 = svdup_n_s32(0);
5612
+ svint32_t sumi2_2 = svdup_n_s32(0);
5613
+ switch (vector_length) {
5614
+ case 128:
5615
+ {
5616
+ for (int j = 0; j < QK_K/64; ++j) {
5617
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
5618
+ svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
5619
+ sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
5620
+ q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
5621
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
5622
+ sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
5623
+
5624
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
5625
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
5626
+ sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
5627
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
5628
+ q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
5629
+ sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
5630
+ q4 += 32;
5631
+ }
5632
+ sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
5633
+ sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
5634
+ sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
5635
+ } break;
5636
+ case 256:
5637
+ case 512:
5638
+ {
5639
+ for (int j = 0; j < QK_K/64; ++j) {
5640
+ const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
5641
+ svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
5642
+ svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
5643
+ sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
5644
+
5645
+ q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
5646
+ q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
5647
+ sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
5648
+ }
5649
+ sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
5650
+ } break;
5651
+ default:
5652
+ assert(false && "Unsupported vector length");
5653
+ break;
5654
+ }
5655
+ }
5656
+ *s = sumf;
5657
+ #elif __ARM_NEON
5573
5658
  const uint8x16_t m4b = vdupq_n_u8(0xf);
5574
5659
  const int32x4_t mzero = vdupq_n_s32(0);
5575
5660