@fugood/llama.node 0.3.3 → 0.3.5

This diff reflects the publicly released contents of these package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +29 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +17 -1
  21. package/src/LlamaContext.cpp +86 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp

@@ -4,8 +4,11 @@
  #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
  #endif

+ #include "amx.h"
  #include "mmq.h"
  #include "ggml-impl.h"
+ #include "ggml-cpu-impl.h"
+ #include "ggml-cpu-quants.h"
  #include "ggml-quants.h"
  #include <algorithm>
  #include <type_traits>
@@ -15,10 +18,6 @@
  #include <unistd.h>
  #endif

- #if defined(_OPENMP)
- #include <omp.h>
- #endif
-
  #if (defined(_WIN32) || defined(_WIN64))
  #define RESTRICT __restrict
  #else
@@ -33,7 +32,7 @@
  #define ALWAYS_INLINE inline
  #endif

- #if defined(__AMX_INT8__)
+ #if defined(__AMX_INT8__) && defined(__AVX512VNNI__)

  namespace {

@@ -496,13 +495,12 @@ inline void from_float(const float * x, char * vy, int64_t k);

  template <>
  inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
- // FIXME: using unoptimized reference impl until moved to CPU backend
- quantize_row_q8_0_ref(x, (block_q8_0 *)vy, k);
+ quantize_row_q8_0(x, (block_q8_0 *)vy, k);
  }

  template <>
  inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
- quantize_row_q8_1_ref(x, (block_q8_1 *)vy, k);
+ quantize_row_q8_1(x, (block_q8_1 *)vy, k);
  }

  template <>
@@ -950,7 +948,7 @@ template<typename TB, typename packed_B_t = packed_B_type<TB>>
  void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) {
  GGML_UNUSED(tile);
  GGML_UNUSED(packed_B);
- };
+ }

  template <>
  void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) {
@@ -1338,21 +1336,19 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];

- auto loadc = [&](int idx) {
+ auto loadc = [&](auto idx) {
  vc[idx] = _mm512_setzero_ps();
  };
  Unroll<ROWS * COLS>{}(loadc);

- auto compute = [&](int idx, int k) {
- // TODO: use `constexpr` here to get rid of interger div
- // when upgraded to C++17
- const int row = idx / COLS;
- const int col = idx % COLS;
+ auto compute = [&](auto idx, auto k) {
+ constexpr int row = idx / COLS;
+ constexpr int col = idx % COLS;

- if (col == 0) {
+ if constexpr (col == 0) {
  va = _mm512_loadu_ps(A + row * K + k);
  }
- if (row == 0) {
+ if constexpr (row == 0) {
  vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
  }
  vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
@@ -1362,9 +1358,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
  Unroll<ROWS * COLS>{}(compute, k);
  }

- auto storec = [&](int idx) {
- const int row = idx / COLS;
- const int col = idx % COLS;
+ auto storec = [&](auto idx) {
+ constexpr int row = idx / COLS;
+ constexpr int col = idx % COLS;
  C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
  };
  Unroll<ROWS * COLS>{}(storec);
@@ -1382,13 +1378,13 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K>
  #define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size

  template<typename TB, int BLOCK_K>
- void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) {
+ void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) {
  const int NB = N / TILE_N;
  const int KB = K / BLOCK_K;
  const int TILE_SIZE = get_tile_size<TB>();

  // parallel on NB should be enough
- parallel_for(n_threads, NB, [&](int begin, int end) {
+ parallel_for(NB, [&](int begin, int end) {
  for (int n = begin; n < end; ++n) {
  for (int k = 0; k < KB; ++k) {
  int n0 = n * TILE_N;
@@ -1427,14 +1423,14 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLO
  const __m512i off = _mm512_set1_epi8(8);
  const __m512i lowMask = _mm512_set1_epi8(0xF);

- auto loadc = [&](int col) {
+ auto loadc = [&](auto col) {
  vc[col] = _mm512_setzero_ps();
  };
  Unroll<COLS>{}(loadc);

- auto compute = [&](int col, int i) {
+ auto compute = [&](auto col, auto i) {
  // load a and compute compensation
- if (col == 0) {
+ if constexpr (col == 0) {
  const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
  vcomp = _mm512_setzero_si512();
  for (int k = 0; k < 8; ++k) {
@@ -1466,7 +1462,7 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLO
  }

  //store to C
- auto storec = [&](int col) {
+ auto storec = [&](auto col) {
  _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
  };
  Unroll<COLS>{}(storec);
@@ -1490,14 +1486,14 @@ struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K>

  const __m512i lowMask = _mm512_set1_epi8(0xF);

- auto loadc = [&](int col) {
+ auto loadc = [&](auto col) {
  vc[col] = _mm512_setzero_ps();
  };
  Unroll<COLS>{}(loadc);

- auto compute = [&](int col, int i) {
+ auto compute = [&](auto col, auto i) {
  // load a
- if (col == 0) {
+ if constexpr (col == 0) {
  const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
  for (int k = 0; k < 8; ++k) {
  va[k] = _mm512_set1_epi32(a_ptr[k]);
@@ -1531,7 +1527,7 @@ struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K>
  }

  //store to C
- auto storec = [&](int col) {
+ auto storec = [&](auto col) {
  _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
  };
  Unroll<COLS>{}(storec);
@@ -1562,14 +1558,14 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLO
  //
  const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));

- auto loadc = [&](int col) {
+ auto loadc = [&](auto col) {
  vc[col] = _mm512_setzero_ps();
  };
  Unroll<COLS>{}(loadc);

- auto compute = [&](int col, int i) {
+ auto compute = [&](auto col, auto i) {
  // load a and add offset 128
- if (col == 0) {
+ if constexpr (col == 0) {
  const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
  for (int k = 0; k < 8; ++k) {
  va[k] = _mm512_set1_epi32(a_ptr[k]);
@@ -1602,7 +1598,7 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLO
  }

  //store to C
- auto storec = [&](int col) {
+ auto storec = [&](auto col) {
  _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
  };
  Unroll<COLS>{}(storec);
@@ -1634,7 +1630,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLO

  const __m512i lowMask = _mm512_set1_epi8(0xF);

- auto loadc = [&](int col) {
+ auto loadc = [&](auto col) {
  vc[col] = _mm512_setzero_ps();
  };
  Unroll<COLS>{}(loadc);
@@ -1648,9 +1644,9 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLO
  // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8
  // from {16, 8} to {4, 32}
  //
- auto compute = [&](int col, int i) {
+ auto compute = [&](auto col, auto i) {
  // load a
- if (col == 0) {
+ if constexpr (col == 0) {
  for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
  va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
  }
@@ -1702,7 +1698,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLO
  }

  //store to C
- auto storec = [&](int col) {
+ auto storec = [&](auto col) {
  _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
  };
  Unroll<COLS>{}(storec);
@@ -1735,15 +1731,15 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLO

  const __m512i lowMask = _mm512_set1_epi8(0xF);

- auto loadc = [&](int col) {
+ auto loadc = [&](auto col) {
  vc[col] = _mm512_setzero_ps();
  };
  Unroll<COLS>{}(loadc);

  // Q5_K and Q4_K shares the same vnni formats, refer to notes above.
- auto compute = [&](int col, int i) {
+ auto compute = [&](auto col, auto i) {
  // load a
- if (col == 0) {
+ if constexpr (col == 0) {
  for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
  va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
  }
@@ -1808,7 +1804,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLO
  }

  //store to C
- auto storec = [&](int col) {
+ auto storec = [&](auto col) {
  _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
  };
  Unroll<COLS>{}(storec);
@@ -1841,13 +1837,13 @@ struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLO
  const __m512i m32s = _mm512_set1_epi32(32);
  const __m512i lowMask = _mm512_set1_epi8(0xF);

- auto loadc = [&](int col) {
+ auto loadc = [&](auto col) {
  vc[col] = _mm512_setzero_ps();
  };
  Unroll<COLS>{}(loadc);

- auto compute = [&](int col, int i) {
- if (col == 0) {
+ auto compute = [&](auto col, auto i) {
+ if constexpr (col == 0) {
  // load a
  va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
  va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
@@ -1959,13 +1955,13 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
  const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
  const __m512i values256 = _mm512_add_epi8(values128, off);

- auto loadc = [&](int col) {
+ auto loadc = [&](auto col) {
  vc[col] = _mm512_setzero_ps();
  };
  Unroll<COLS>{}(loadc);

- auto compute = [&](int col, int i) {
- if (col == 0) {
+ auto compute = [&](auto col, auto i) {
+ if constexpr (col == 0) {
  // load a
  va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0));
  va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64));
@@ -2015,7 +2011,7 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B
  }

  //store to C
- auto storec = [&](int col) {
+ auto storec = [&](auto col) {
  _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
  };
  Unroll<COLS>{}(storec);
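Most of the kernel hunks above make the same change: the loadc/compute/storec lambdas take their index as a generic auto parameter instead of int, so the row/column split can be computed as constexpr and the per-column branches become if constexpr, which discards the untaken branch at compile time. This only works because Unroll<N> invokes the lambda with a compile-time integer type. Below is a minimal standalone sketch of that pattern; the Unroll here is a hypothetical re-creation for illustration (the real helper is defined elsewhere in mmq.cpp and is not part of this diff).

#include <cstdio>
#include <type_traits>
#include <utility>

constexpr int ROWS = 2;
constexpr int COLS = 3;

// Hypothetical re-creation of the Unroll helper used above: it invokes
// f(std::integral_constant<int, 0>{}), f(std::integral_constant<int, 1>{}), ...
// so the index is a compile-time constant inside the lambda body.
template <int N>
struct Unroll {
    template <typename Func, typename... Args>
    void operator()(const Func & f, Args... args) const {
        expand(f, std::make_integer_sequence<int, N>{}, args...);
    }
private:
    template <typename Func, int... Is, typename... Args>
    static void expand(const Func & f, std::integer_sequence<int, Is...>, Args... args) {
        (f(std::integral_constant<int, Is>{}, args...), ...); // C++17 fold over the comma operator
    }
};

int main() {
    // Because idx is an integral_constant, idx / COLS and idx % COLS are
    // constant expressions, so constexpr locals and if constexpr are legal,
    // mirroring the loadc/compute/storec changes in the hunks above.
    auto body = [](auto idx) {
        constexpr int row = idx / COLS;
        constexpr int col = idx % COLS;
        if constexpr (col == 0) {
            std::printf("row %d:", row);
        }
        std::printf(" (%d,%d)", row, col);
        if constexpr (col == COLS - 1) {
            std::printf("\n");
        }
    };
    Unroll<ROWS * COLS>{}(body);
    return 0;
}

The remaining mmq.cpp hunks below change the backend entry points rather than the kernels.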
@@ -2327,25 +2323,39 @@ size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) {

  // pack weight to vnni format
  void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
-
- size_t alloc_size = ggml_backend_amx_get_alloc_size(tensor);
- GGML_ASSERT(alloc_size == size);
+ GGML_ASSERT(offset == 0 && size == ggml_nbytes(tensor)); // only full tensor conversion is supported for now

  const enum ggml_type TYPE = tensor->type;

  const int K = tensor->ne[0]; // ne0: in_features
  const int N = tensor->ne[1]; // ne1: out_features

- #if defined(_OPENMP)
- // the buffer ctx is not initialized when .set_tensor is called
- int n_threads = omp_get_num_threads();
- #else
- int n_threads = 1;
- #endif
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
+ convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K);
+ });
+ }
+
+ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
+ struct ggml_tensor * src0 = dst->src[0];
+
+ const enum ggml_type TYPE = src0->type;
+
+ const bool is_floating_type = TYPE == GGML_TYPE_F16;
+ if (is_floating_type) {
+ return 0;
+ }
+
+ const int M = dst->ne[1];
+ const int K = src0->ne[0];
+
+ size_t desired_wsize = 0;

  GGML_DISPATCH_QTYPES(TYPE, [&] {
- convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads);
+ const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
+ desired_wsize = M * row_size_A;
  });
+
+ return desired_wsize;
  }

  // NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)
@@ -2356,14 +2366,12 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
  //
  // the function performs: dst = src1 @ src0.T
  //
- void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
+ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
  struct ggml_tensor * src0 = dst->src[0];
  struct ggml_tensor * src1 = dst->src[1];

  const enum ggml_type TYPE = src0->type;

- const int n_threads = ctx->n_threads;
-
  // f16 only has avx512 kernels for now,
  // amx kernels will be added once 6th gen xeon is released.
  const bool is_floating_type = TYPE == GGML_TYPE_F16;
@@ -2379,7 +2387,7 @@ void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor
  const int MB = div_up(M, BLOCK_M);
  const int NB = div_up(N, BLOCK_N);

- parallel_for(n_threads, MB * NB, [&](int begin, int end) {
+ parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
  GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
  for (int i = begin; i < end; ++i) {
  int mb = i / NB;
@@ -2412,27 +2420,29 @@ void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor
  }

  // pointer to work space, used convert A from float to quantized type
- void * wdata = nullptr;
+ void * wdata = params->wdata;

  //TODO: performance improvement: merge quant A
- GGML_DISPATCH_QTYPES(TYPE, [&] {
- const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
- const size_t desired_wsize = M * row_size_A;
- if (ctx->work_size < desired_wsize) {
- ctx->work_data.reset(new char[desired_wsize]);
- ctx->work_size = desired_wsize;
- }
- wdata = ctx->work_data.get();
+ if (params->ith == 0) {
+ GGML_DISPATCH_QTYPES(TYPE, [&] {
+ const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
+ const size_t desired_wsize = M * row_size_A;
+ if (params->wsize < desired_wsize) {
+ GGML_ABORT("insufficient work space size");
+ }

- // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size
- // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
- GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
+ // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size
+ // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
+ GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);

- const float * A_data = static_cast<const float *>(src1->data);
- for (int m = 0; m < M; ++m) {
- from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
- }
- });
+ const float * A_data = static_cast<const float *>(src1->data);
+ for (int m = 0; m < M; ++m) {
+ from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K);
+ }
+ });
+ }
+
+ ggml_barrier(params->threadpool);

  if (M == 1) {
  // MB = 1 and handle 8 tiles in each block
@@ -2440,7 +2450,7 @@ void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor
  constexpr int BLOCK_N = TILE_N * kTilesN;
  const int NB = div_up(N, BLOCK_N);

- parallel_for(n_threads, NB, [&](int begin, int end) {
+ parallel_for_ggml(params, NB, [&](int begin, int end) {
  GGML_DISPATCH_QTYPES(TYPE, [&] {
  const int KB = K / blck_size;
  const int TILE_SIZE = get_tile_size<type>();
@@ -2470,7 +2480,7 @@ void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor
  const int MB = div_up(M, BLOCK_M);
  const int NB = div_up(N, BLOCK_N);

- parallel_for(n_threads, MB * NB, [&](int begin, int end) {
+ parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
  // init tile config for each thread
  ggml_tile_config_init();

@@ -2498,13 +2508,4 @@ void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor
  });
  }

- #else // if defined(__AMX_INT8__)
-
- void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
- fprintf(stderr, "GGML is not compiled with AMX support!\n");
-
- GGML_UNUSED(ctx);
- GGML_UNUSED(dst);
- }
-
- #endif // if defined(__AMX_INT8__)
+ #endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
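The mul_mat hunks above drop the context-owned work buffer and the OpenMP thread count in favour of the caller-provided ggml_compute_params: thread 0 quantizes the activations into params->wdata, every thread synchronizes at ggml_barrier, and the tile blocks are then partitioned across threads via parallel_for_ggml. The sketch below is only a plain-C++ analogue of that calling pattern (std::thread and std::barrier, which requires C++20, stand in for ggml's threadpool and ggml_barrier; none of this is ggml code):

#include <algorithm>
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int n_threads    = 4;
    const int n_work_items = 10;            // stands in for the MB * NB tile blocks
    std::vector<float> wdata(1024, 0.0f);   // stands in for params->wdata
    std::barrier<> sync(n_threads);         // stands in for ggml_barrier(params->threadpool)

    auto worker = [&](int ith) {            // ith plays the role of params->ith
        if (ith == 0) {
            // thread 0 prepares the shared scratch (the quantized copy of A)
            for (float & x : wdata) x = 1.0f;
        }
        sync.arrive_and_wait();             // after this point every thread sees a filled buffer

        // static partition of the work items, like parallel_for_ggml
        const int per_thread = (n_work_items + n_threads - 1) / n_threads;
        const int begin = ith * per_thread;
        const int end   = std::min(n_work_items, begin + per_thread);
        for (int i = begin; i < end; ++i) {
            std::printf("thread %d handles block %d (wdata[0] = %.1f)\n", ith, i, wdata[0]);
        }
    };

    std::vector<std::thread> pool;
    for (int ith = 0; ith < n_threads; ++ith) {
        pool.emplace_back(worker, ith);
    }
    for (std::thread & t : pool) {
        t.join();
    }
    return 0;
}

The mmq.h hunk below updates the public declarations to match.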
package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h

@@ -1,17 +1,10 @@
  #pragma once
  #include "common.h"
- #include <stdint.h>

- #ifdef __cplusplus
- extern "C" {
- #endif
+ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst);

  size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);

  void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);

- void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst);
-
- #ifdef __cplusplus
- }
- #endif
+ void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
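The ggml_backend_amx_desired_wsize declaration above reports how large the params->wdata scratch buffer must be before ggml_backend_amx_mul_mat is called. A worked example of the formula from the mmq.cpp hunks (row_size_A = K / blck_size * sizeof(vec_dot_type), desired_wsize = M * row_size_A), assuming Q8_0 as the vec_dot_type with QK8_0 = 32 values per block and sizeof(block_q8_0) = 34 bytes (32 int8 quants plus one fp16 scale); those constants come from ggml's block definitions and are restated here as assumptions:

#include <cstddef>
#include <cstdio>

int main() {
    const int M = 4;                            // dst->ne[1]: rows of activations to quantize
    const int K = 4096;                         // src0->ne[0]: in_features
    const int blck_size = 32;                   // QK8_0 (assumed)
    const std::size_t vec_dot_type_size = 34;   // sizeof(block_q8_0) = 2 + 32 (assumed)

    const std::size_t row_size_A    = (std::size_t)(K / blck_size) * vec_dot_type_size; // 128 * 34 = 4352
    const std::size_t desired_wsize = (std::size_t)M * row_size_A;                      // 4 * 4352 = 17408

    std::printf("row_size_A = %zu bytes, desired_wsize = %zu bytes\n", row_size_A, desired_wsize);
    return 0;
}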