@fugood/llama.node 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225) hide show
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +1 -1
  21. package/src/LlamaContext.cpp +81 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
@@ -3,13 +3,14 @@
3
3
 
4
4
  #include "ggml-backend-impl.h"
5
5
  #include "ggml-backend.h"
6
- #include "ggml-cpu-aarch64.h"
6
+ #include "ggml-cpu-traits.h"
7
7
  #include "ggml-cpu-impl.h"
8
8
  #include "ggml-cpu.h"
9
9
  #include "ggml-impl.h"
10
10
  #include "ggml-quants.h"
11
11
  #include "ggml-cpu-quants.h"
12
12
  #include "ggml-threading.h"
13
+ #include "amx/amx.h"
13
14
  #include "ggml.h"
14
15
 
15
16
  #if defined(_MSC_VER) || defined(__MINGW32__)
@@ -109,10 +110,11 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
109
110
  #if defined(__ARM_ARCH)
110
111
  struct ggml_arm_arch_features_type {
111
112
  int has_neon;
113
+ int has_dotprod;
112
114
  int has_i8mm;
113
115
  int has_sve;
114
116
  int sve_cnt;
115
- } ggml_arm_arch_features = {-1, -1, -1, 0};
117
+ } ggml_arm_arch_features = {-1, -1, -1, -1, 0};
116
118
  #endif
117
119
 
118
120
 
@@ -124,8 +126,7 @@ struct ggml_arm_arch_features_type {
124
126
  #endif
125
127
  #include <windows.h>
126
128
 
127
-
128
- #if !defined(__clang__)
129
+ #if defined(_MSC_VER) && !defined(__clang__)
129
130
  #define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
130
131
 
131
132
  typedef volatile LONG atomic_int;
@@ -222,10 +223,6 @@ typedef void * thread_ret_t;
222
223
 
223
224
  typedef pthread_t ggml_thread_t;
224
225
 
225
- #ifdef GGML_USE_CPU_HBM
226
- #include <hbwmalloc.h>
227
- #endif
228
-
229
226
  #if defined(__APPLE__)
230
227
  #include <unistd.h>
231
228
  #include <mach/mach.h>
@@ -299,7 +296,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
299
296
  },
300
297
  [GGML_TYPE_Q8_0] = {
301
298
  .from_float = quantize_row_q8_0,
302
- .from_float_to_mat = quantize_mat_q8_0,
303
299
  .vec_dot = ggml_vec_dot_q8_0_q8_0,
304
300
  .vec_dot_type = GGML_TYPE_Q8_0,
305
301
  #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -407,33 +403,6 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
407
403
  .vec_dot_type = GGML_TYPE_BF16,
408
404
  .nrows = 1,
409
405
  },
410
- [GGML_TYPE_Q4_0_4_4] = {
411
- .from_float = NULL,
412
- .vec_dot = NULL,
413
- .vec_dot_type = GGML_TYPE_Q8_0,
414
- .nrows = 1,
415
- .ncols = 4,
416
- .gemv = ggml_gemv_q4_0_4x4_q8_0,
417
- .gemm = ggml_gemm_q4_0_4x4_q8_0,
418
- },
419
- [GGML_TYPE_Q4_0_4_8] = {
420
- .from_float = NULL,
421
- .vec_dot = NULL,
422
- .vec_dot_type = GGML_TYPE_Q8_0,
423
- .nrows = 1,
424
- .ncols = 4,
425
- .gemv = ggml_gemv_q4_0_4x8_q8_0,
426
- .gemm = ggml_gemm_q4_0_4x8_q8_0,
427
- },
428
- [GGML_TYPE_Q4_0_8_8] = {
429
- .from_float = NULL,
430
- .vec_dot = NULL,
431
- .vec_dot_type = GGML_TYPE_Q8_0,
432
- .nrows = 1,
433
- .ncols = 8,
434
- .gemv = ggml_gemv_q4_0_8x8_q8_0,
435
- .gemm = ggml_gemm_q4_0_8x8_q8_0,
436
- },
437
406
  [GGML_TYPE_TQ1_0] = {
438
407
  .from_float = quantize_row_tq1_0,
439
408
  .vec_dot = ggml_vec_dot_tq1_0_q8_K,
@@ -485,21 +454,21 @@ const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type
485
454
  #define GGML_F32x4_ADD vaddq_f32
486
455
  #define GGML_F32x4_MUL vmulq_f32
487
456
  #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
488
- #define GGML_F32x4_REDUCE(res, x) \
489
- { \
490
- int offset = GGML_F32_ARR >> 1; \
491
- for (int i = 0; i < offset; ++i) { \
492
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
493
- } \
494
- offset >>= 1; \
495
- for (int i = 0; i < offset; ++i) { \
496
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
497
- } \
498
- offset >>= 1; \
499
- for (int i = 0; i < offset; ++i) { \
500
- (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
501
- } \
502
- (res) = GGML_F32x4_REDUCE_ONE((x)[0]); \
457
+ #define GGML_F32x4_REDUCE(res, x) \
458
+ { \
459
+ int offset = GGML_F32_ARR >> 1; \
460
+ for (int i = 0; i < offset; ++i) { \
461
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
462
+ } \
463
+ offset >>= 1; \
464
+ for (int i = 0; i < offset; ++i) { \
465
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
466
+ } \
467
+ offset >>= 1; \
468
+ for (int i = 0; i < offset; ++i) { \
469
+ (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \
470
+ } \
471
+ (res) = (ggml_float) GGML_F32x4_REDUCE_ONE((x)[0]); \
503
472
  }
504
473
 
505
474
  #define GGML_F32_VEC GGML_F32x4
@@ -614,7 +583,7 @@ do { \
614
583
  for (int i = 0; i < offset; ++i) { \
615
584
  x[i] = _mm512_add_ps(x[i], x[offset+i]); \
616
585
  } \
617
- res = _mm512_reduce_add_ps(x[0]); \
586
+ res = (ggml_float) _mm512_reduce_add_ps(x[0]); \
618
587
  } while (0)
619
588
 
620
589
  // TODO: is this optimal ?
@@ -664,7 +633,7 @@ do { \
664
633
  for (int i = 0; i < offset; ++i) { \
665
634
  x[i] = _mm512_add_ps(x[i], x[offset+i]); \
666
635
  } \
667
- res = _mm512_reduce_add_ps(x[0]); \
636
+ res = (ggml_float) _mm512_reduce_add_ps(x[0]); \
668
637
  } while (0)
669
638
 
670
639
  #define GGML_F16_VEC GGML_F32Cx16
@@ -675,8 +644,8 @@ do { \
675
644
  #define GGML_F16_VEC_FMA GGML_F32Cx16_FMA
676
645
  #define GGML_F16_VEC_ADD GGML_F32Cx16_ADD
677
646
  #define GGML_F16_VEC_MUL GGML_F32Cx16_MUL
678
- #define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE
679
647
 
648
+ #define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE
680
649
  #elif defined(__AVX__)
681
650
 
682
651
  #define GGML_SIMD
@@ -745,7 +714,7 @@ do { \
745
714
  #define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
746
715
  #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
747
716
  #else
748
- static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
717
+ static inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) {
749
718
  float tmp[8];
750
719
 
751
720
  for (int i = 0; i < 8; i++) {
@@ -1168,28 +1137,28 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
1168
1137
  #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
1169
1138
  #define GGML_F32x4_ADD __lsx_vfadd_s
1170
1139
  #define GGML_F32x4_MUL __lsx_vfmul_s
1171
- #define GGML_F32x4_REDUCE(res, x) \
1172
- { \
1173
- int offset = GGML_F32_ARR >> 1; \
1174
- for (int i = 0; i < offset; ++i) { \
1175
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1176
- } \
1177
- offset >>= 1; \
1178
- for (int i = 0; i < offset; ++i) { \
1179
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1180
- } \
1181
- offset >>= 1; \
1182
- for (int i = 0; i < offset; ++i) { \
1183
- x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \
1184
- } \
1185
- __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \
1186
- tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \
1187
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1188
- const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1189
- tmp = __lsx_vsrli_d((__m128i)t0, 32); \
1190
- tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \
1191
- tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1192
- res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1140
+ #define GGML_F32x4_REDUCE(res, x) \
1141
+ { \
1142
+ int offset = GGML_F32_ARR >> 1; \
1143
+ for (int i = 0; i < offset; ++i) { \
1144
+ x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
1145
+ } \
1146
+ offset >>= 1; \
1147
+ for (int i = 0; i < offset; ++i) { \
1148
+ x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
1149
+ } \
1150
+ offset >>= 1; \
1151
+ for (int i = 0; i < offset; ++i) { \
1152
+ x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \
1153
+ } \
1154
+ __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \
1155
+ tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \
1156
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1157
+ const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \
1158
+ tmp = __lsx_vsrli_d((__m128i) t0, 32); \
1159
+ tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \
1160
+ tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \
1161
+ res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
1193
1162
  }
1194
1163
 
1195
1164
  #define GGML_F32_VEC GGML_F32x4
@@ -1357,31 +1326,18 @@ struct ggml_compute_state {
1357
1326
  int ith;
1358
1327
  };
1359
1328
 
1360
- struct ggml_compute_params {
1361
- // ith = thread index, nth = number of threads
1362
- int ith, nth;
1363
-
1364
- // work buffer for all threads
1365
- size_t wsize;
1366
- void * wdata;
1367
-
1368
- struct ggml_threadpool * threadpool;
1369
- };
1370
-
1371
1329
  //
1372
1330
  // fundamental operations
1373
1331
  //
1374
1332
 
1375
1333
  inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1376
-
1377
1334
  inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1378
1335
 
1379
- inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1336
+ inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1337
+ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
1380
1338
 
1381
1339
  inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1382
-
1383
1340
  inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
1384
-
1385
1341
  inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
1386
1342
  inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
1387
1343
  inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
@@ -2276,7 +2232,7 @@ struct ggml_state {
2276
2232
 
2277
2233
  static struct ggml_state g_state = {0};
2278
2234
 
2279
- static void ggml_barrier(struct ggml_threadpool * tp) {
2235
+ void ggml_barrier(struct ggml_threadpool * tp) {
2280
2236
  int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
2281
2237
  if (n_threads == 1) {
2282
2238
  return;
@@ -2430,7 +2386,7 @@ bool ggml_is_numa(void) {
2430
2386
  #endif
2431
2387
 
2432
2388
  #if !defined(HWCAP2_I8MM)
2433
- #define HWCAP2_I8MM 0
2389
+ #define HWCAP2_I8MM (1 << 13)
2434
2390
  #endif
2435
2391
 
2436
2392
  static void ggml_init_arm_arch_features(void) {
@@ -2439,6 +2395,7 @@ static void ggml_init_arm_arch_features(void) {
2439
2395
  uint32_t hwcap2 = getauxval(AT_HWCAP2);
2440
2396
 
2441
2397
  ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
2398
+ ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
2442
2399
  ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
2443
2400
  ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
2444
2401
 
@@ -2453,6 +2410,11 @@ static void ggml_init_arm_arch_features(void) {
2453
2410
  }
2454
2411
  ggml_arm_arch_features.has_neon = oldp;
2455
2412
 
2413
+ if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
2414
+ oldp = 0;
2415
+ }
2416
+ ggml_arm_arch_features.has_dotprod = oldp;
2417
+
2456
2418
  if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
2457
2419
  oldp = 0;
2458
2420
  }
@@ -4505,9 +4467,6 @@ static void ggml_compute_forward_add(
4505
4467
  case GGML_TYPE_IQ4_XS:
4506
4468
  case GGML_TYPE_IQ3_S:
4507
4469
  case GGML_TYPE_IQ2_S:
4508
- case GGML_TYPE_Q4_0_4_4:
4509
- case GGML_TYPE_Q4_0_4_8:
4510
- case GGML_TYPE_Q4_0_8_8:
4511
4470
  {
4512
4471
  ggml_compute_forward_add_q_f32(params, dst);
4513
4472
  } break;
@@ -4885,9 +4844,6 @@ static void ggml_compute_forward_add1(
4885
4844
  case GGML_TYPE_IQ4_XS:
4886
4845
  case GGML_TYPE_IQ3_S:
4887
4846
  case GGML_TYPE_IQ2_S:
4888
- case GGML_TYPE_Q4_0_4_4:
4889
- case GGML_TYPE_Q4_0_4_8:
4890
- case GGML_TYPE_Q4_0_8_8:
4891
4847
  {
4892
4848
  ggml_compute_forward_add1_q_f32(params, dst);
4893
4849
  } break;
@@ -5015,9 +4971,6 @@ static void ggml_compute_forward_acc(
5015
4971
  case GGML_TYPE_IQ4_XS:
5016
4972
  case GGML_TYPE_IQ3_S:
5017
4973
  case GGML_TYPE_IQ2_S:
5018
- case GGML_TYPE_Q4_0_4_4:
5019
- case GGML_TYPE_Q4_0_4_8:
5020
- case GGML_TYPE_Q4_0_8_8:
5021
4974
  default:
5022
4975
  {
5023
4976
  GGML_ABORT("fatal error");
@@ -7433,20 +7386,9 @@ static void ggml_compute_forward_mul_mat(
7433
7386
  const int ith = params->ith;
7434
7387
  const int nth = params->nth;
7435
7388
 
7436
- enum ggml_type type = src0->type;
7437
-
7438
- if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
7439
- type = (enum ggml_type)(intptr_t)src0->extra;
7440
- }
7441
-
7442
- enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7389
+ enum ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
7443
7390
  ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
7444
- ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat;
7445
- int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows;
7446
- int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
7447
- int64_t const blck_size_interleave = ggml_get_type_traits(type)->blck_size_interleave;
7448
- ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
7449
- ggml_gemm_t const gemm = type_traits_cpu[type].gemm;
7391
+ int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows;
7450
7392
 
7451
7393
  GGML_ASSERT(ne0 == ne01);
7452
7394
  GGML_ASSERT(ne1 == ne11);
@@ -7454,7 +7396,7 @@ static void ggml_compute_forward_mul_mat(
7454
7396
  GGML_ASSERT(ne3 == ne13);
7455
7397
 
7456
7398
  // we don't support permuted src0 or src1
7457
- GGML_ASSERT(nb00 == ggml_type_size(type));
7399
+ GGML_ASSERT(nb00 == ggml_type_size(src0->type));
7458
7400
  GGML_ASSERT(nb10 == ggml_type_size(src1->type));
7459
7401
 
7460
7402
  // dst cannot be transposed or permuted
@@ -7466,6 +7408,7 @@ static void ggml_compute_forward_mul_mat(
7466
7408
  // nb01 >= nb00 - src0 is not transposed
7467
7409
  // compute by src0 rows
7468
7410
 
7411
+ // TODO: extract to "extra_op"
7469
7412
  #if GGML_USE_LLAMAFILE
7470
7413
  // broadcast factors
7471
7414
  const int64_t r2 = ne12 / ne02;
@@ -7476,15 +7419,15 @@ static void ggml_compute_forward_mul_mat(
7476
7419
  if (src1_cont) {
7477
7420
  for (int64_t i13 = 0; i13 < ne13; i13++)
7478
7421
  for (int64_t i12 = 0; i12 < ne12; i12++)
7479
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
7422
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
7480
7423
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
7481
- nb01/ggml_type_size(type),
7424
+ nb01/ggml_type_size(src0->type),
7482
7425
  (const char *)src1->data + i12*nb12 + i13*nb13,
7483
7426
  nb11/ggml_type_size(src1->type),
7484
7427
  (char *)dst->data + i12*nb2 + i13*nb3,
7485
7428
  nb1/ggml_type_size(dst->type),
7486
7429
  ith, nth,
7487
- type,
7430
+ src0->type,
7488
7431
  src1->type,
7489
7432
  dst->type))
7490
7433
  goto UseGgmlGemm1;
@@ -7505,19 +7448,10 @@ UseGgmlGemm1:;
7505
7448
 
7506
7449
  for (int64_t i13 = 0; i13 < ne13; ++i13) {
7507
7450
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
7508
- int64_t i11_processed = 0;
7509
- if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
7510
- for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
7511
- from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
7512
- (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7513
- 4, ne10, blck_size_interleave);
7514
- }
7515
- i11_processed = ne11 - ne11 % 4;
7516
- }
7517
- for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
7451
+ for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
7518
7452
  from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
7519
- (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7520
- ne10);
7453
+ (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
7454
+ ne10);
7521
7455
  }
7522
7456
  }
7523
7457
  }
@@ -7537,15 +7471,15 @@ UseGgmlGemm1:;
7537
7471
 
7538
7472
  for (int64_t i13 = 0; i13 < ne13; i13++)
7539
7473
  for (int64_t i12 = 0; i12 < ne12; i12++)
7540
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
7474
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
7541
7475
  (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
7542
- nb01/ggml_type_size(type),
7476
+ nb01/ggml_type_size(src0->type),
7543
7477
  (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
7544
7478
  row_size/ggml_type_size(vec_dot_type),
7545
7479
  (char *)dst->data + i12*nb2 + i13*nb3,
7546
7480
  nb1/ggml_type_size(dst->type),
7547
7481
  ith, nth,
7548
- type,
7482
+ src0->type,
7549
7483
  vec_dot_type,
7550
7484
  dst->type))
7551
7485
  goto UseGgmlGemm2;
@@ -7560,14 +7494,6 @@ UseGgmlGemm2:;
7560
7494
  // This is the size of the rest of the dimensions of the result
7561
7495
  const int64_t nr1 = ne1 * ne2 * ne3;
7562
7496
 
7563
- // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
7564
- int64_t num_rows_per_vec_dot = vec_dot_num_rows;
7565
- // TODO: currently the mmla kernels support only even numbered rows/cols.
7566
- // this check can be removed once they are extended to support odd numbered rows/cols too
7567
- if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
7568
- num_rows_per_vec_dot = 1;
7569
- }
7570
-
7571
7497
  // Now select a reasonable chunk size.
7572
7498
  int chunk_size = 16;
7573
7499
 
@@ -7595,28 +7521,6 @@ UseGgmlGemm2:;
7595
7521
  const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
7596
7522
  const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
7597
7523
 
7598
- if ((ggml_n_dims(src0) == 2) && gemv) {
7599
- const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
7600
- const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
7601
- int64_t src0_start = (ith * ne01) / nth;
7602
- int64_t src0_end = ((ith + 1) * ne01) / nth;
7603
- src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
7604
- src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
7605
- if (src0_start >= src0_end) return;
7606
-
7607
- // If there are more than three rows in src1, use gemm; otherwise, use gemv.
7608
- if (gemm && (ne11 > 3)) {
7609
- gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
7610
- (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
7611
- }
7612
- for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
7613
- gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
7614
- (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
7615
- src0_end - src0_start);
7616
- }
7617
- return;
7618
- }
7619
-
7620
7524
  // The first chunk comes from our thread_id, the rest will get auto-assigned.
7621
7525
  int current_chunk = ith;
7622
7526
 
@@ -7630,7 +7534,16 @@ UseGgmlGemm2:;
7630
7534
  const int64_t ir1_start = dr1 * ith1;
7631
7535
  const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
7632
7536
 
7633
- ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
7537
+ // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
7538
+ int64_t num_rows_per_vec_dot = vec_dot_num_rows;
7539
+
7540
+ // these checks are needed to avoid crossing dim1 boundaries
7541
+ // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
7542
+ if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
7543
+ num_rows_per_vec_dot = 1;
7544
+ }
7545
+
7546
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
7634
7547
 
7635
7548
  if (nth >= nchunk0 * nchunk1) {
7636
7549
  break;
@@ -7662,8 +7575,6 @@ static void ggml_compute_forward_mul_mat_id(
7662
7575
  ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
7663
7576
  enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
7664
7577
  ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
7665
- int64_t const matmul_num_cols = type_traits_cpu[type].ncols;
7666
- ggml_gemv_t const gemv = type_traits_cpu[type].gemv;
7667
7578
 
7668
7579
  // we don't support permuted src0 or src1
7669
7580
  GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -7749,34 +7660,6 @@ static void ggml_compute_forward_mul_mat_id(
7749
7660
  const int64_t nr0 = ne01; // src0 rows
7750
7661
  const int64_t nr1 = cne1; // src1 rows
7751
7662
 
7752
- if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
7753
- int64_t src0_cur_start = (ith * ne01) / nth;
7754
- int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
7755
- src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
7756
- src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
7757
- if (src0_cur_start >= src0_cur_end) return;
7758
-
7759
- for (int ir1 = 0; ir1 < nr1; ir1++) {
7760
- struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
7761
- const int id = row_mapping.i1; // selected expert index
7762
-
7763
- const int64_t i11 = id % ne11;
7764
- const int64_t i12 = row_mapping.i2; // row index in src1
7765
-
7766
- const int64_t i1 = id; // selected expert index
7767
- const int64_t i2 = i12; // row
7768
-
7769
- const char * src1_col = (const char *) wdata +
7770
- (src1_cont || src1->type != vec_dot_type
7771
- ? (i11 + i12 * ne11) * row_size
7772
- : (i11 * nb11 + i12 * nb12));
7773
-
7774
- gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
7775
- (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
7776
- }
7777
- continue;
7778
- }
7779
-
7780
7663
  // distribute the thread work across the inner or outer loop based on which one is larger
7781
7664
 
7782
7665
  const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
@@ -8084,9 +7967,6 @@ static void ggml_compute_forward_out_prod(
8084
7967
  case GGML_TYPE_IQ4_XS:
8085
7968
  case GGML_TYPE_IQ3_S:
8086
7969
  case GGML_TYPE_IQ2_S:
8087
- case GGML_TYPE_Q4_0_4_4:
8088
- case GGML_TYPE_Q4_0_4_8:
8089
- case GGML_TYPE_Q4_0_8_8:
8090
7970
  {
8091
7971
  ggml_compute_forward_out_prod_q_f32(params, dst);
8092
7972
  } break;
@@ -8239,6 +8119,77 @@ static void ggml_compute_forward_set_f32(
8239
8119
  }
8240
8120
  }
8241
8121
 
8122
+ static void ggml_compute_forward_set_i32(
8123
+ const struct ggml_compute_params * params,
8124
+ struct ggml_tensor * dst) {
8125
+
8126
+ const struct ggml_tensor * src0 = dst->src[0];
8127
+ const struct ggml_tensor * src1 = dst->src[1];
8128
+
8129
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
8130
+ GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
8131
+
8132
+ // view src0 and dst with these strides and data offset inbytes during set
8133
+ // nb0 is implicitly element_size because src0 and dst are contiguous
8134
+ size_t nb1 = ((int32_t *) dst->op_params)[0];
8135
+ size_t nb2 = ((int32_t *) dst->op_params)[1];
8136
+ size_t nb3 = ((int32_t *) dst->op_params)[2];
8137
+ size_t offset = ((int32_t *) dst->op_params)[3];
8138
+ bool inplace = (bool) ((int32_t *) dst->op_params)[4];
8139
+
8140
+ if (!inplace) {
8141
+ if (params->ith == 0) {
8142
+ // memcpy needs to be synchronized across threads to avoid race conditions.
8143
+ // => do it in INIT phase
8144
+ memcpy(
8145
+ ((char *) dst->data),
8146
+ ((char *) src0->data),
8147
+ ggml_nbytes(dst));
8148
+ }
8149
+ ggml_barrier(params->threadpool);
8150
+ }
8151
+
8152
+ const int ith = params->ith;
8153
+ const int nth = params->nth;
8154
+
8155
+ const int nr = ggml_nrows(src1);
8156
+ const int nc = src1->ne[0];
8157
+
8158
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
8159
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
8160
+
8161
+ // src0 and dst as viewed during set
8162
+ const size_t nb0 = ggml_element_size(src0);
8163
+
8164
+ const int im0 = (ne10 == 0 ? 0 : ne10-1);
8165
+ const int im1 = (ne11 == 0 ? 0 : ne11-1);
8166
+ const int im2 = (ne12 == 0 ? 0 : ne12-1);
8167
+ const int im3 = (ne13 == 0 ? 0 : ne13-1);
8168
+
8169
+ GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
8170
+
8171
+ GGML_ASSERT(nb10 == sizeof(int32_t));
8172
+
8173
+ // rows per thread
8174
+ const int dr = (nr + nth - 1)/nth;
8175
+
8176
+ // row range for this thread
8177
+ const int ir0 = dr*ith;
8178
+ const int ir1 = MIN(ir0 + dr, nr);
8179
+
8180
+ for (int ir = ir0; ir < ir1; ++ir) {
8181
+ // src0 and dst are viewed with shape of src1 and offset
8182
+ // => same indices
8183
+ const int i3 = ir/(ne12*ne11);
8184
+ const int i2 = (ir - i3*ne12*ne11)/ne11;
8185
+ const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
8186
+
8187
+ ggml_vec_cpy_i32(nc,
8188
+ (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
8189
+ (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
8190
+ }
8191
+ }
8192
+
8242
8193
  static void ggml_compute_forward_set(
8243
8194
  const struct ggml_compute_params * params,
8244
8195
  struct ggml_tensor * dst) {
@@ -8250,6 +8201,10 @@ static void ggml_compute_forward_set(
8250
8201
  {
8251
8202
  ggml_compute_forward_set_f32(params, dst);
8252
8203
  } break;
8204
+ case GGML_TYPE_I32:
8205
+ {
8206
+ ggml_compute_forward_set_i32(params, dst);
8207
+ } break;
8253
8208
  case GGML_TYPE_F16:
8254
8209
  case GGML_TYPE_BF16:
8255
8210
  case GGML_TYPE_Q4_0:
@@ -8274,9 +8229,6 @@ static void ggml_compute_forward_set(
8274
8229
  case GGML_TYPE_IQ4_XS:
8275
8230
  case GGML_TYPE_IQ3_S:
8276
8231
  case GGML_TYPE_IQ2_S:
8277
- case GGML_TYPE_Q4_0_4_4:
8278
- case GGML_TYPE_Q4_0_4_8:
8279
- case GGML_TYPE_Q4_0_8_8:
8280
8232
  default:
8281
8233
  {
8282
8234
  GGML_ABORT("fatal error");
@@ -8538,9 +8490,6 @@ static void ggml_compute_forward_get_rows(
8538
8490
  case GGML_TYPE_IQ4_XS:
8539
8491
  case GGML_TYPE_IQ3_S:
8540
8492
  case GGML_TYPE_IQ2_S:
8541
- case GGML_TYPE_Q4_0_4_4:
8542
- case GGML_TYPE_Q4_0_4_8:
8543
- case GGML_TYPE_Q4_0_8_8:
8544
8493
  {
8545
8494
  ggml_compute_forward_get_rows_q(params, dst);
8546
8495
  } break;
@@ -9130,9 +9079,6 @@ static void ggml_compute_forward_clamp(
9130
9079
  case GGML_TYPE_IQ3_S:
9131
9080
  case GGML_TYPE_IQ2_S:
9132
9081
  case GGML_TYPE_Q8_K:
9133
- case GGML_TYPE_Q4_0_4_4:
9134
- case GGML_TYPE_Q4_0_4_8:
9135
- case GGML_TYPE_Q4_0_8_8:
9136
9082
  case GGML_TYPE_I8:
9137
9083
  case GGML_TYPE_I16:
9138
9084
  case GGML_TYPE_I32:
@@ -9187,6 +9133,64 @@ static void ggml_rope_cache_init(
9187
9133
  }
9188
9134
  }
9189
9135
 
9136
+ static void ggml_mrope_cache_init(
9137
+ float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
9138
+ float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
9139
+ float * cache, float sin_sign, float theta_scale) {
9140
+ // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
9141
+ float theta_t = theta_base_t;
9142
+ float theta_h = theta_base_h;
9143
+ float theta_w = theta_base_w;
9144
+ float theta_e = theta_base_e; // extra position id for vision encoder
9145
+ int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
9146
+ int sec_w = sections[1] + sections[0];
9147
+ int sec_e = sections[2] + sec_w;
9148
+ GGML_ASSERT(sect_dims <= ne0);
9149
+
9150
+ for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
9151
+ const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
9152
+
9153
+ int sector = (i0 / 2) % sect_dims;
9154
+ if (indep_sects) {
9155
+ // compute theta independently for each dim sections
9156
+ // (i.e. reset corresponding theta when `i0` go from one section to another)
9157
+ if (sector == 0) {
9158
+ theta_t = theta_base_t;
9159
+ }
9160
+ else if (sector == sections[0]) {
9161
+ theta_h = theta_base_h;;
9162
+ }
9163
+ else if (sector == sec_w) {
9164
+ theta_w = theta_base_w;
9165
+ }
9166
+ else if (sector == sec_e) {
9167
+ theta_e = theta_base_e;
9168
+ }
9169
+ }
9170
+
9171
+ float theta = theta_t;
9172
+ if (sector >= sections[0] && sector < sec_w) {
9173
+ theta = theta_h;
9174
+ }
9175
+ else if (sector >= sec_w && sector < sec_w + sections[2]) {
9176
+ theta = theta_w;
9177
+ }
9178
+ else if (sector >= sec_w + sections[2]) {
9179
+ theta = theta_e;
9180
+ }
9181
+
9182
+ rope_yarn(
9183
+ theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
9184
+ );
9185
+ cache[i0 + 1] *= sin_sign;
9186
+
9187
+ theta_t *= theta_scale;
9188
+ theta_w *= theta_scale;
9189
+ theta_h *= theta_scale;
9190
+ theta_e *= theta_scale;
9191
+ }
9192
+ }
9193
+
9190
9194
  static void ggml_compute_forward_rope_f32(
9191
9195
  const struct ggml_compute_params * params,
9192
9196
  struct ggml_tensor * dst,
@@ -9197,6 +9201,7 @@ static void ggml_compute_forward_rope_f32(
9197
9201
  const struct ggml_tensor * src2 = dst->src[2];
9198
9202
 
9199
9203
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
9204
+ int sections[4];
9200
9205
 
9201
9206
  //const int n_past = ((int32_t *) dst->op_params)[0];
9202
9207
  const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -9210,6 +9215,7 @@ static void ggml_compute_forward_rope_f32(
9210
9215
  memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
9211
9216
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
9212
9217
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
9218
+ memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
9213
9219
 
9214
9220
  GGML_TENSOR_UNARY_OP_LOCALS
9215
9221
 
@@ -9242,6 +9248,16 @@ static void ggml_compute_forward_rope_f32(
9242
9248
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
9243
9249
 
9244
9250
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
9251
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
9252
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
9253
+
9254
+ if (is_mrope) {
9255
+ GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
9256
+ }
9257
+
9258
+ if (is_vision) {
9259
+ GGML_ASSERT(n_dims == ne0/2);
9260
+ }
9245
9261
 
9246
9262
  const float * freq_factors = NULL;
9247
9263
  if (src2 != NULL) {
@@ -9257,18 +9273,63 @@ static void ggml_compute_forward_rope_f32(
9257
9273
 
9258
9274
  const int32_t * pos = (const int32_t *) src1->data;
9259
9275
 
9260
- for (int64_t i3 = 0; i3 < ne3; i3++) {
9261
- for (int64_t i2 = 0; i2 < ne2; i2++) {
9262
- const int64_t p = pos[i2];
9276
+ for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
9277
+ for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
9263
9278
 
9264
9279
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
9265
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9280
+ if (!is_mrope) {
9281
+ const int64_t p = pos[i2];
9282
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9283
+ }
9284
+ else {
9285
+ const int64_t p_t = pos[i2];
9286
+ const int64_t p_h = pos[i2 + ne2];
9287
+ const int64_t p_w = pos[i2 + ne2 * 2];
9288
+ const int64_t p_e = pos[i2 + ne2 * 3];
9289
+ ggml_mrope_cache_init(
9290
+ p_t, p_h, p_w, p_e, sections, is_vision,
9291
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9292
+ }
9266
9293
 
9267
- for (int64_t i1 = 0; i1 < ne1; i1++) {
9294
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
9268
9295
  if (ir++ < ir0) continue;
9269
9296
  if (ir > ir1) break;
9270
9297
 
9271
- if (!is_neox) {
9298
+ if (is_neox || is_mrope) {
9299
+ if (is_vision){
9300
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9301
+ const int64_t ic = i0/2;
9302
+
9303
+ const float cos_theta = cache[i0 + 0];
9304
+ const float sin_theta = cache[i0 + 1];
9305
+
9306
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9307
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9308
+
9309
+ const float x0 = src[0];
9310
+ const float x1 = src[n_dims];
9311
+
9312
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9313
+ dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
9314
+ }
9315
+ } else {
9316
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9317
+ const int64_t ic = i0/2;
9318
+
9319
+ const float cos_theta = cache[i0 + 0];
9320
+ const float sin_theta = cache[i0 + 1];
9321
+
9322
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9323
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9324
+
9325
+ const float x0 = src[0];
9326
+ const float x1 = src[n_dims/2];
9327
+
9328
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9329
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
9330
+ }
9331
+ }
9332
+ } else {
9272
9333
  for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9273
9334
  const float cos_theta = cache[i0 + 0];
9274
9335
  const float sin_theta = cache[i0 + 1];
@@ -9282,8 +9343,10 @@ static void ggml_compute_forward_rope_f32(
9282
9343
  dst_data[0] = x0*cos_theta - x1*sin_theta;
9283
9344
  dst_data[1] = x0*sin_theta + x1*cos_theta;
9284
9345
  }
9285
- } else {
9286
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9346
+ }
9347
+
9348
+ if (is_vision) {
9349
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9287
9350
  const int64_t ic = i0/2;
9288
9351
 
9289
9352
  const float cos_theta = cache[i0 + 0];
@@ -9293,19 +9356,20 @@ static void ggml_compute_forward_rope_f32(
9293
9356
  float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9294
9357
 
9295
9358
  const float x0 = src[0];
9296
- const float x1 = src[n_dims/2];
9359
+ const float x1 = src[n_dims];
9297
9360
 
9298
- dst_data[0] = x0*cos_theta - x1*sin_theta;
9299
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
9361
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
9362
+ dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
9300
9363
  }
9301
- }
9302
-
9303
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9304
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9305
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9364
+ } else {
9365
+ // fill the remain channels with data from src tensor
9366
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9367
+ const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9368
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9306
9369
 
9307
- dst_data[0] = src[0];
9308
- dst_data[1] = src[1];
9370
+ dst_data[0] = src[0];
9371
+ dst_data[1] = src[1];
9372
+ }
9309
9373
  }
9310
9374
  }
9311
9375
  }
@@ -9323,6 +9387,7 @@ static void ggml_compute_forward_rope_f16(
9323
9387
  const struct ggml_tensor * src2 = dst->src[2];
9324
9388
 
9325
9389
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
9390
+ int sections[4];
9326
9391
 
9327
9392
  //const int n_past = ((int32_t *) dst->op_params)[0];
9328
9393
  const int n_dims = ((int32_t *) dst->op_params)[1];
@@ -9335,6 +9400,8 @@ static void ggml_compute_forward_rope_f16(
9335
9400
  memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
9336
9401
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
9337
9402
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
9403
+ memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
9404
+
9338
9405
 
9339
9406
  GGML_TENSOR_UNARY_OP_LOCALS
9340
9407
 
@@ -9367,6 +9434,16 @@ static void ggml_compute_forward_rope_f16(
9367
9434
  ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
9368
9435
 
9369
9436
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
9437
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
9438
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
9439
+
9440
+ if (is_mrope) {
9441
+ GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
9442
+ }
9443
+
9444
+ if (is_vision) {
9445
+ GGML_ASSERT(n_dims == ne0/2);
9446
+ }
9370
9447
 
9371
9448
  const float * freq_factors = NULL;
9372
9449
  if (src2 != NULL) {
@@ -9384,16 +9461,61 @@ static void ggml_compute_forward_rope_f16(
9384
9461
 
9385
9462
  for (int64_t i3 = 0; i3 < ne3; i3++) {
9386
9463
  for (int64_t i2 = 0; i2 < ne2; i2++) {
9387
- const int64_t p = pos[i2];
9388
9464
 
9389
9465
  float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
9390
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9466
+ if (!is_mrope) {
9467
+ const int64_t p = pos[i2];
9468
+ ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9469
+ }
9470
+ else {
9471
+ const int64_t p_t = pos[i2];
9472
+ const int64_t p_h = pos[i2 + ne2];
9473
+ const int64_t p_w = pos[i2 + ne2 * 2];
9474
+ const int64_t p_e = pos[i2 + ne2 * 3];
9475
+ ggml_mrope_cache_init(
9476
+ p_t, p_h, p_w, p_e, sections, is_vision,
9477
+ freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
9478
+ }
9391
9479
 
9392
9480
  for (int64_t i1 = 0; i1 < ne1; i1++) {
9393
9481
  if (ir++ < ir0) continue;
9394
9482
  if (ir > ir1) break;
9395
9483
 
9396
- if (!is_neox) {
9484
+ if (is_neox || is_mrope) {
9485
+ if (is_vision) {
9486
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9487
+ const int64_t ic = i0/2;
9488
+
9489
+ const float cos_theta = cache[i0 + 0];
9490
+ const float sin_theta = cache[i0 + 1];
9491
+
9492
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9493
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9494
+
9495
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
9496
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
9497
+
9498
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9499
+ dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9500
+ }
9501
+ } else {
9502
+ for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9503
+ const int64_t ic = i0/2;
9504
+
9505
+ const float cos_theta = cache[i0 + 0];
9506
+ const float sin_theta = cache[i0 + 1];
9507
+
9508
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
9509
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9510
+
9511
+ const float x0 = GGML_FP16_TO_FP32(src[0]);
9512
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
9513
+
9514
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9515
+ dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9516
+ }
9517
+ }
9518
+ } else {
9397
9519
  for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9398
9520
  const float cos_theta = cache[i0 + 0];
9399
9521
  const float sin_theta = cache[i0 + 1];
@@ -9407,8 +9529,10 @@ static void ggml_compute_forward_rope_f16(
9407
9529
  dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9408
9530
  dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9409
9531
  }
9410
- } else {
9411
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
9532
+ }
9533
+
9534
+ if (is_vision) {
9535
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9412
9536
  const int64_t ic = i0/2;
9413
9537
 
9414
9538
  const float cos_theta = cache[i0 + 0];
@@ -9418,19 +9542,19 @@ static void ggml_compute_forward_rope_f16(
9418
9542
  ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
9419
9543
 
9420
9544
  const float x0 = GGML_FP16_TO_FP32(src[0]);
9421
- const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
9545
+ const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
9422
9546
 
9423
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9424
- dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9547
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
9548
+ dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
9425
9549
  }
9426
- }
9427
-
9428
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9429
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9430
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9550
+ } else {
9551
+ for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
9552
+ const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
9553
+ ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
9431
9554
 
9432
- dst_data[0] = src[0];
9433
- dst_data[1] = src[1];
9555
+ dst_data[0] = src[0];
9556
+ dst_data[1] = src[1];
9557
+ }
9434
9558
  }
9435
9559
  }
9436
9560
  }
@@ -10429,6 +10553,40 @@ static void ggml_compute_forward_pad(
10429
10553
  }
10430
10554
  }
10431
10555
 
10556
+ // ggml_compute_forward_pad_reflect_1d
10557
+
10558
+ static void ggml_compute_forward_pad_reflect_1d(
10559
+ const struct ggml_compute_params * params,
10560
+ struct ggml_tensor * dst) {
10561
+
10562
+ const struct ggml_tensor * src0 = dst->src[0];
10563
+
10564
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
10565
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
10566
+
10567
+ const int ith = params->ith;
10568
+ const int nth = params->nth;
10569
+
10570
+ const int32_t * opts = (const int32_t *) dst->op_params;
10571
+ const int p0 = opts[0];
10572
+ const int p1 = opts[1];
10573
+
10574
+ GGML_TENSOR_UNARY_OP_LOCALS
10575
+
10576
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
10577
+ for (int64_t i2 = 0; i2 < ne2; i2++) {
10578
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
10579
+ float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
10580
+ float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
10581
+
10582
+ ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
10583
+
10584
+ for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
10585
+ for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
10586
+ }
10587
+ }
10588
+ }
10589
+ }
10432
10590
 
10433
10591
  // ggml_compute_forward_arange
10434
10592
 
@@ -12304,6 +12462,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12304
12462
  return;
12305
12463
  }
12306
12464
 
12465
+ // extra_buffer op?
12466
+ if (ggml_cpu_extra_compute_forward(params, tensor)) return;
12467
+
12307
12468
  switch (tensor->op) {
12308
12469
  case GGML_OP_DUP:
12309
12470
  {
@@ -12525,6 +12686,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
12525
12686
  {
12526
12687
  ggml_compute_forward_pad(params, tensor);
12527
12688
  } break;
12689
+ case GGML_OP_PAD_REFLECT_1D:
12690
+ {
12691
+ ggml_compute_forward_pad_reflect_1d(params, tensor);
12692
+ } break;
12528
12693
  case GGML_OP_ARANGE:
12529
12694
  {
12530
12695
  ggml_compute_forward_arange(params, tensor);
@@ -12867,6 +13032,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
12867
13032
  } break;
12868
13033
  case GGML_OP_UPSCALE:
12869
13034
  case GGML_OP_PAD:
13035
+ case GGML_OP_PAD_REFLECT_1D:
12870
13036
  case GGML_OP_ARANGE:
12871
13037
  case GGML_OP_TIMESTEP_EMBEDDING:
12872
13038
  case GGML_OP_ARGSORT:
@@ -12956,7 +13122,7 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data);
12956
13122
  #include "windows.h"
12957
13123
 
12958
13124
  // TODO: support > 64 CPUs
12959
- bool ggml_thread_apply_affinity(bool * mask) {
13125
+ static bool ggml_thread_apply_affinity(bool * mask) {
12960
13126
  HANDLE h = GetCurrentThread();
12961
13127
  uint64_t bitmask = 0ULL;
12962
13128
 
@@ -13246,140 +13412,142 @@ struct ggml_cplan ggml_graph_plan(
13246
13412
 
13247
13413
  size_t cur = 0;
13248
13414
 
13249
- switch (node->op) {
13250
- case GGML_OP_CPY:
13251
- case GGML_OP_DUP:
13252
- {
13253
- if (ggml_is_quantized(node->type) ||
13254
- // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
13255
- (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
13256
- (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
13415
+ if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
13416
+
13417
+ switch (node->op) {
13418
+ case GGML_OP_CPY:
13419
+ case GGML_OP_DUP:
13420
+ {
13421
+ if (ggml_is_quantized(node->type) ||
13422
+ // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
13423
+ (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
13424
+ (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
13425
+ cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
13426
+ }
13427
+ } break;
13428
+ case GGML_OP_ADD:
13429
+ case GGML_OP_ADD1:
13430
+ {
13431
+ if (ggml_is_quantized(node->src[0]->type)) {
13432
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
13433
+ }
13434
+ } break;
13435
+ case GGML_OP_ACC:
13436
+ {
13437
+ if (ggml_is_quantized(node->src[0]->type)) {
13438
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
13439
+ }
13440
+ } break;
13441
+ case GGML_OP_COUNT_EQUAL:
13442
+ {
13443
+ cur = ggml_type_size(node->type)*n_tasks;
13444
+ } break;
13445
+ case GGML_OP_MUL_MAT:
13446
+ {
13447
+ const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
13448
+
13449
+ if (node->src[1]->type != vec_dot_type) {
13450
+ cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
13451
+ }
13452
+ } break;
13453
+ case GGML_OP_MUL_MAT_ID:
13454
+ {
13455
+ cur = 0;
13456
+ const struct ggml_tensor * src0 = node->src[0];
13457
+ const struct ggml_tensor * src1 = node->src[1];
13458
+ const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
13459
+ if (src1->type != vec_dot_type) {
13460
+ cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
13461
+ }
13462
+ const int n_as = src0->ne[2];
13463
+ cur += GGML_PAD(cur, sizeof(int64_t)); // align
13464
+ cur += n_as * sizeof(int64_t); // matrix_row_counts
13465
+ cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
13466
+ } break;
13467
+ case GGML_OP_OUT_PROD:
13468
+ {
13469
+ if (ggml_is_quantized(node->src[0]->type)) {
13470
+ cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
13471
+ }
13472
+ } break;
13473
+ case GGML_OP_SOFT_MAX:
13474
+ case GGML_OP_ROPE:
13475
+ {
13257
13476
  cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
13258
- }
13259
- } break;
13260
- case GGML_OP_ADD:
13261
- case GGML_OP_ADD1:
13262
- {
13263
- if (ggml_is_quantized(node->src[0]->type)) {
13264
- cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
13265
- }
13266
- } break;
13267
- case GGML_OP_ACC:
13268
- {
13269
- if (ggml_is_quantized(node->src[0]->type)) {
13270
- cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
13271
- }
13272
- } break;
13273
- case GGML_OP_COUNT_EQUAL:
13274
- {
13275
- cur = ggml_type_size(node->type)*n_tasks;
13276
- } break;
13277
- case GGML_OP_MUL_MAT:
13278
- {
13279
- const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
13477
+ } break;
13478
+ case GGML_OP_CONV_TRANSPOSE_1D:
13479
+ {
13480
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
13481
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
13482
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
13483
+
13484
+ const int64_t ne00 = node->src[0]->ne[0]; // K
13485
+ const int64_t ne01 = node->src[0]->ne[1]; // Cout
13486
+ const int64_t ne02 = node->src[0]->ne[2]; // Cin
13487
+ const int64_t ne10 = node->src[1]->ne[0]; // L
13488
+ const int64_t ne11 = node->src[1]->ne[1]; // Cin
13489
+
13490
+ if ((node->src[0]->type == GGML_TYPE_F16 ||
13491
+ node->src[0]->type == GGML_TYPE_BF16) &&
13492
+ node->src[1]->type == GGML_TYPE_F32) {
13493
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
13494
+ cur += sizeof(ggml_fp16_t)*ne10*ne11;
13495
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
13496
+ node->src[1]->type == GGML_TYPE_F32) {
13497
+ cur += sizeof(float)*ne00*ne01*ne02;
13498
+ cur += sizeof(float)*ne10*ne11;
13499
+ } else {
13500
+ GGML_ABORT("fatal error");
13501
+ }
13502
+ } break;
13503
+ case GGML_OP_CONV_TRANSPOSE_2D:
13504
+ {
13505
+ const int64_t ne00 = node->src[0]->ne[0]; // W
13506
+ const int64_t ne01 = node->src[0]->ne[1]; // H
13507
+ const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
13508
+ const int64_t ne03 = node->src[0]->ne[3]; // Channels In
13280
13509
 
13281
- if (node->src[1]->type != vec_dot_type) {
13282
- cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
13283
- }
13284
- } break;
13285
- case GGML_OP_MUL_MAT_ID:
13286
- {
13287
- cur = 0;
13288
- const struct ggml_tensor * src0 = node->src[0];
13289
- const struct ggml_tensor * src1 = node->src[1];
13290
- const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
13291
- if (src1->type != vec_dot_type) {
13292
- cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
13293
- }
13294
- const int n_as = src0->ne[2];
13295
- cur += GGML_PAD(cur, sizeof(int64_t)); // align
13296
- cur += n_as * sizeof(int64_t); // matrix_row_counts
13297
- cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
13298
- } break;
13299
- case GGML_OP_OUT_PROD:
13300
- {
13301
- if (ggml_is_quantized(node->src[0]->type)) {
13302
- cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
13303
- }
13304
- } break;
13305
- case GGML_OP_SOFT_MAX:
13306
- case GGML_OP_ROPE:
13307
- {
13308
- cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
13309
- } break;
13310
- case GGML_OP_CONV_TRANSPOSE_1D:
13311
- {
13312
- GGML_ASSERT(node->src[0]->ne[3] == 1);
13313
- GGML_ASSERT(node->src[1]->ne[2] == 1);
13314
- GGML_ASSERT(node->src[1]->ne[3] == 1);
13315
-
13316
- const int64_t ne00 = node->src[0]->ne[0]; // K
13317
- const int64_t ne01 = node->src[0]->ne[1]; // Cout
13318
- const int64_t ne02 = node->src[0]->ne[2]; // Cin
13319
-
13320
- const int64_t ne10 = node->src[1]->ne[0]; // L
13321
- const int64_t ne11 = node->src[1]->ne[1]; // Cin
13322
-
13323
- if ((node->src[0]->type == GGML_TYPE_F16 ||
13324
- node->src[0]->type == GGML_TYPE_BF16) &&
13325
- node->src[1]->type == GGML_TYPE_F32) {
13326
- cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
13327
- cur += sizeof(ggml_fp16_t)*ne10*ne11;
13328
- } else if (node->src[0]->type == GGML_TYPE_F32 &&
13329
- node->src[1]->type == GGML_TYPE_F32) {
13330
- cur += sizeof(float)*ne00*ne01*ne02;
13331
- cur += sizeof(float)*ne10*ne11;
13332
- } else {
13333
- GGML_ABORT("fatal error");
13334
- }
13335
- } break;
13336
- case GGML_OP_CONV_TRANSPOSE_2D:
13337
- {
13338
- const int64_t ne00 = node->src[0]->ne[0]; // W
13339
- const int64_t ne01 = node->src[0]->ne[1]; // H
13340
- const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
13341
- const int64_t ne03 = node->src[0]->ne[3]; // Channels In
13342
-
13343
- const int64_t ne10 = node->src[1]->ne[0]; // W
13344
- const int64_t ne11 = node->src[1]->ne[1]; // H
13345
- const int64_t ne12 = node->src[1]->ne[2]; // Channels In
13346
-
13347
- cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
13348
- cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
13349
- } break;
13350
- case GGML_OP_FLASH_ATTN_EXT:
13351
- {
13352
- const int64_t ne00 = node->src[0]->ne[0]; // D
13510
+ const int64_t ne10 = node->src[1]->ne[0]; // W
13511
+ const int64_t ne11 = node->src[1]->ne[1]; // H
13512
+ const int64_t ne12 = node->src[1]->ne[2]; // Channels In
13353
13513
 
13354
- cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
13355
- } break;
13356
- case GGML_OP_FLASH_ATTN_BACK:
13357
- {
13358
- const int64_t D = node->src[0]->ne[0];
13359
- const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
13360
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
13361
- if (node->src[1]->type == GGML_TYPE_F32) {
13362
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
13363
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
13364
- } else if (node->src[1]->type == GGML_TYPE_F16) {
13365
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
13366
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
13367
- } else if (node->src[1]->type == GGML_TYPE_BF16) {
13368
- cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
13369
- cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
13370
- }
13371
- } break;
13514
+ cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
13515
+ cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
13516
+ } break;
13517
+ case GGML_OP_FLASH_ATTN_EXT:
13518
+ {
13519
+ const int64_t ne00 = node->src[0]->ne[0]; // D
13372
13520
 
13373
- case GGML_OP_CROSS_ENTROPY_LOSS:
13374
- {
13375
- cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
13376
- } break;
13377
- case GGML_OP_COUNT:
13378
- {
13379
- GGML_ABORT("fatal error");
13380
- }
13381
- default:
13382
- break;
13521
+ cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
13522
+ } break;
13523
+ case GGML_OP_FLASH_ATTN_BACK:
13524
+ {
13525
+ const int64_t D = node->src[0]->ne[0];
13526
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
13527
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
13528
+ if (node->src[1]->type == GGML_TYPE_F32) {
13529
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
13530
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
13531
+ } else if (node->src[1]->type == GGML_TYPE_F16) {
13532
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
13533
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
13534
+ } else if (node->src[1]->type == GGML_TYPE_BF16) {
13535
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
13536
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
13537
+ }
13538
+ } break;
13539
+
13540
+ case GGML_OP_CROSS_ENTROPY_LOSS:
13541
+ {
13542
+ cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
13543
+ } break;
13544
+ case GGML_OP_COUNT:
13545
+ {
13546
+ GGML_ABORT("fatal error");
13547
+ }
13548
+ default:
13549
+ break;
13550
+ }
13383
13551
  }
13384
13552
 
13385
13553
  work_size = MAX(work_size, cur);
@@ -13578,29 +13746,6 @@ static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int
13578
13746
 
13579
13747
  #endif // GGML_USE_OPENMP
13580
13748
 
13581
- void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
13582
- p->n_threads = n_threads;
13583
- p->prio = 0; // default priority (usually means normal or inherited)
13584
- p->poll = 50; // hybrid-polling enabled
13585
- p->strict_cpu = false; // no strict placement (all threads share same cpumask)
13586
- p->paused = false; // threads are ready to go
13587
- memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
13588
- }
13589
-
13590
- struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
13591
- struct ggml_threadpool_params p;
13592
- ggml_threadpool_params_init(&p, n_threads);
13593
- return p;
13594
- }
13595
-
13596
- bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
13597
- if (p0->n_threads != p1->n_threads ) return false;
13598
- if (p0->prio != p1->prio ) return false;
13599
- if (p0->poll != p1->poll ) return false;
13600
- if (p0->strict_cpu != p1->strict_cpu ) return false;
13601
- return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
13602
- }
13603
-
13604
13749
  static struct ggml_threadpool * ggml_threadpool_new_impl(
13605
13750
  struct ggml_threadpool_params * tpp,
13606
13751
  struct ggml_cgraph * cgraph,
@@ -13896,15 +14041,23 @@ int ggml_cpu_has_vsx(void) {
13896
14041
  }
13897
14042
 
13898
14043
  int ggml_cpu_has_neon(void) {
13899
- #if defined(__ARM_ARCH)
14044
+ #if defined(__ARM_ARCH) && defined(__ARM_NEON)
13900
14045
  return ggml_arm_arch_features.has_neon;
13901
14046
  #else
13902
14047
  return 0;
13903
14048
  #endif
13904
14049
  }
13905
14050
 
14051
+ int ggml_cpu_has_dotprod(void) {
14052
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
14053
+ return ggml_arm_arch_features.has_dotprod;
14054
+ #else
14055
+ return 0;
14056
+ #endif
14057
+ }
14058
+
13906
14059
  int ggml_cpu_has_sve(void) {
13907
- #if defined(__ARM_ARCH)
14060
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
13908
14061
  return ggml_arm_arch_features.has_sve;
13909
14062
  #else
13910
14063
  return 0;
@@ -13912,7 +14065,7 @@ int ggml_cpu_has_sve(void) {
13912
14065
  }
13913
14066
 
13914
14067
  int ggml_cpu_has_matmul_int8(void) {
13915
- #if defined(__ARM_ARCH)
14068
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
13916
14069
  return ggml_arm_arch_features.has_i8mm;
13917
14070
  #else
13918
14071
  return 0;
@@ -13920,7 +14073,7 @@ int ggml_cpu_has_matmul_int8(void) {
13920
14073
  }
13921
14074
 
13922
14075
  int ggml_cpu_get_sve_cnt(void) {
13923
- #if defined(__ARM_ARCH)
14076
+ #if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
13924
14077
  return ggml_arm_arch_features.sve_cnt;
13925
14078
  #else
13926
14079
  return 0;